From c6c92ae5b25b49e62b06f8867f6a7b0046f04428 Mon Sep 17 00:00:00 2001 From: mxmrlv Date: Sun, 27 Nov 2016 13:20:46 +0200 Subject: [PATCH] ARIA-30 SQL based storage implementation --- aria/__init__.py | 43 +- aria/orchestrator/__init__.py | 4 +- aria/orchestrator/context/common.py | 29 +- aria/orchestrator/context/exceptions.py | 4 +- aria/orchestrator/context/operation.py | 27 +- aria/orchestrator/context/toolbelt.py | 20 +- aria/orchestrator/context/workflow.py | 51 +- aria/orchestrator/exceptions.py | 7 +- aria/orchestrator/workflows/api/task.py | 10 +- aria/orchestrator/workflows/builtin/heal.py | 25 +- .../orchestrator/workflows/builtin/install.py | 7 +- .../workflows/builtin/uninstall.py | 7 +- .../workflows/builtin/workflows.py | 13 +- aria/orchestrator/workflows/core/engine.py | 6 +- aria/orchestrator/workflows/core/task.py | 38 +- aria/storage/__init__.py | 372 +----- aria/storage/api.py | 182 +++ aria/storage/core.py | 125 ++ aria/storage/drivers.py | 416 ------ aria/storage/exceptions.py | 4 +- aria/storage/filesystem_rapi.py | 150 +++ aria/storage/models.py | 702 +++++----- aria/storage/sql_mapi.py | 382 ++++++ aria/storage/structures.py | 399 +++--- aria/utils/application.py | 14 +- requirements.txt | 1 + tests/mock/context.py | 50 +- tests/mock/models.py | 102 +- tests/orchestrator/context/test_operation.py | 80 +- tests/orchestrator/context/test_toolbelt.py | 92 +- tests/orchestrator/context/test_workflow.py | 37 +- tests/orchestrator/workflows/api/test_task.py | 76 +- .../workflows/builtin/__init__.py | 35 +- .../builtin/test_execute_operation.py | 17 +- .../workflows/builtin/test_heal.py | 23 +- .../workflows/builtin/test_install.py | 16 +- .../workflows/builtin/test_uninstall.py | 13 +- .../workflows/core/test_engine.py | 47 +- .../orchestrator/workflows/core/test_task.py | 37 +- .../test_task_graph_into_exececution_graph.py | 15 +- tests/requirements.txt | 2 +- tests/storage/__init__.py | 75 +- tests/storage/test_drivers.py | 135 -- tests/storage/test_field.py | 124 -- tests/storage/test_model_storage.py | 134 +- tests/storage/test_models.py | 1143 ++++++++++++----- tests/storage/test_models_api.py | 70 - tests/storage/test_resource_storage.py | 62 +- 48 files changed, 2854 insertions(+), 2569 deletions(-) create mode 100644 aria/storage/api.py create mode 100644 aria/storage/core.py delete mode 100644 aria/storage/drivers.py create mode 100644 aria/storage/filesystem_rapi.py create mode 100644 aria/storage/sql_mapi.py delete mode 100644 tests/storage/test_drivers.py delete mode 100644 tests/storage/test_field.py delete mode 100644 tests/storage/test_models_api.py diff --git a/aria/__init__.py b/aria/__init__.py index 3f81f989..b000397f 100644 --- a/aria/__init__.py +++ b/aria/__init__.py @@ -23,7 +23,6 @@ from .VERSION import version as __version__ from .orchestrator.decorators import workflow, operation -from .storage import ModelStorage, ResourceStorage, models, ModelDriver, ResourceDriver from . import ( utils, parser, @@ -37,7 +36,6 @@ 'operation', ) -_model_storage = {} _resource_storage = {} @@ -58,37 +56,38 @@ def install_aria_extensions(): del sys.modules[module_name] -def application_model_storage(driver): +def application_model_storage(api, api_kwargs=None): """ Initiate model storage for the supplied storage driver """ + models = [ + storage.models.Plugin, + storage.models.ProviderContext, - assert isinstance(driver, ModelDriver) - if driver not in _model_storage: - _model_storage[driver] = ModelStorage( - driver, model_classes=[ - models.Node, - models.NodeInstance, - models.Plugin, - models.Blueprint, - models.Snapshot, - models.Deployment, - models.DeploymentUpdate, - models.DeploymentModification, - models.Execution, - models.ProviderContext, - models.Task, - ]) - return _model_storage[driver] + storage.models.Blueprint, + storage.models.Deployment, + storage.models.DeploymentUpdate, + storage.models.DeploymentUpdateStep, + storage.models.DeploymentModification, + + storage.models.Node, + storage.models.NodeInstance, + storage.models.Relationship, + storage.models.RelationshipInstance, + + storage.models.Execution, + storage.models.Task, + ] + # if api not in _model_storage: + return storage.ModelStorage(api, items=models, api_kwargs=api_kwargs or {}) def application_resource_storage(driver): """ Initiate resource storage for the supplied storage driver """ - assert isinstance(driver, ResourceDriver) if driver not in _resource_storage: - _resource_storage[driver] = ResourceStorage( + _resource_storage[driver] = storage.ResourceStorage( driver, resources=[ 'blueprint', diff --git a/aria/orchestrator/__init__.py b/aria/orchestrator/__init__.py index a5aeec7c..90d64422 100644 --- a/aria/orchestrator/__init__.py +++ b/aria/orchestrator/__init__.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Aria orchestrator +""" from .decorators import workflow, operation from . import ( diff --git a/aria/orchestrator/context/common.py b/aria/orchestrator/context/common.py index f2bf83bc..14efd9d1 100644 --- a/aria/orchestrator/context/common.py +++ b/aria/orchestrator/context/common.py @@ -32,8 +32,7 @@ def __init__( model_storage, resource_storage, deployment_id, - workflow_id, - execution_id=None, + workflow_name, task_max_attempts=1, task_retry_interval=0, task_ignore_failure=False, @@ -44,8 +43,7 @@ def __init__( self._model = model_storage self._resource = resource_storage self._deployment_id = deployment_id - self._workflow_id = workflow_id - self._execution_id = execution_id or str(uuid4()) + self._workflow_name = workflow_name self._task_max_attempts = task_max_attempts self._task_retry_interval = task_retry_interval self._task_ignore_failure = task_ignore_failure @@ -54,8 +52,7 @@ def __repr__(self): return ( '{name}(name={self.name}, ' 'deployment_id={self._deployment_id}, ' - 'workflow_id={self._workflow_id}, ' - 'execution_id={self._execution_id})' + 'workflow_name={self._workflow_name}, ' .format(name=self.__class__.__name__, self=self)) @property @@ -79,7 +76,7 @@ def blueprint(self): """ The blueprint model """ - return self.model.blueprint.get(self.deployment.blueprint_id) + return self.deployment.blueprint @property def deployment(self): @@ -88,20 +85,6 @@ def deployment(self): """ return self.model.deployment.get(self._deployment_id) - @property - def execution(self): - """ - The execution model - """ - return self.model.execution.get(self._execution_id) - - @execution.setter - def execution(self, value): - """ - Store the execution in the model storage - """ - self.model.execution.store(value) - @property def name(self): """ @@ -136,6 +119,6 @@ def get_resource(self, path=None): Read a deployment resource as string from the resource storage """ try: - return self.resource.deployment.data(entry_id=self.deployment.id, path=path) + return self.resource.deployment.read(entry_id=self.deployment.id, path=path) except exceptions.StorageError: - return self.resource.blueprint.data(entry_id=self.blueprint.id, path=path) + return self.resource.blueprint.read(entry_id=self.blueprint.id, path=path) diff --git a/aria/orchestrator/context/exceptions.py b/aria/orchestrator/context/exceptions.py index 6704bbc2..fe762e13 100644 --- a/aria/orchestrator/context/exceptions.py +++ b/aria/orchestrator/context/exceptions.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Context based exceptions +""" from ..exceptions import OrchestratorError diff --git a/aria/orchestrator/context/operation.py b/aria/orchestrator/context/operation.py index bf3686de..a73bad1d 100644 --- a/aria/orchestrator/context/operation.py +++ b/aria/orchestrator/context/operation.py @@ -26,17 +26,17 @@ class BaseOperationContext(BaseContext): Context object used during operation creation and execution """ - def __init__(self, name, workflow_context, task, **kwargs): + def __init__(self, name, workflow_context, task, actor, **kwargs): super(BaseOperationContext, self).__init__( name=name, model_storage=workflow_context.model, resource_storage=workflow_context.resource, deployment_id=workflow_context._deployment_id, - workflow_id=workflow_context._workflow_id, - execution_id=workflow_context._execution_id, + workflow_name=workflow_context._workflow_name, **kwargs) self._task_model = task - self._actor = self.task.actor + self._task_id = task.id + self._actor_id = actor.id def __repr__(self): details = 'operation_mapping={task.operation_mapping}; ' \ @@ -50,7 +50,7 @@ def task(self): The task in the model storage :return: Task model """ - return self._task_model + return self.model.task.get(self._task_id) class NodeOperationContext(BaseOperationContext): @@ -63,7 +63,7 @@ def node(self): the node of the current operation :return: """ - return self._actor.node + return self.node_instance.node @property def node_instance(self): @@ -71,7 +71,7 @@ def node_instance(self): The node instance of the current operation :return: """ - return self._actor + return self.model.node_instance.get(self._actor_id) class RelationshipOperationContext(BaseOperationContext): @@ -84,7 +84,7 @@ def source_node(self): The source node :return: """ - return self.model.node.get(self.relationship.source_id) + return self.relationship.source_node @property def source_node_instance(self): @@ -92,7 +92,7 @@ def source_node_instance(self): The source node instance :return: """ - return self.model.node_instance.get(self.relationship_instance.source_id) + return self.relationship_instance.source_node_instance @property def target_node(self): @@ -100,7 +100,7 @@ def target_node(self): The target node :return: """ - return self.model.node.get(self.relationship.target_id) + return self.relationship.target_node @property def target_node_instance(self): @@ -108,7 +108,7 @@ def target_node_instance(self): The target node instance :return: """ - return self.model.node_instance.get(self._actor.target_id) + return self.relationship_instance.target_node_instance @property def relationship(self): @@ -116,7 +116,8 @@ def relationship(self): The relationship of the current operation :return: """ - return self._actor.relationship + + return self.relationship_instance.relationship @property def relationship_instance(self): @@ -124,4 +125,4 @@ def relationship_instance(self): The relationship instance of the current operation :return: """ - return self._actor + return self.model.relationship_instance.get(self._actor_id) diff --git a/aria/orchestrator/context/toolbelt.py b/aria/orchestrator/context/toolbelt.py index 0aad89cb..301b013c 100644 --- a/aria/orchestrator/context/toolbelt.py +++ b/aria/orchestrator/context/toolbelt.py @@ -26,21 +26,6 @@ class NodeToolBelt(object): def __init__(self, operation_context): self._op_context = operation_context - @property - def dependent_node_instances(self): - """ - Any node instance which has a relationship to the current node instance. - :return: - """ - assert isinstance(self._op_context, operation.NodeOperationContext) - node_instances = self._op_context.model.node_instance.iter( - filters={'deployment_id': self._op_context.deployment.id} - ) - for node_instance in node_instances: - for relationship_instance in node_instance.relationship_instances: - if relationship_instance.target_id == self._op_context.node_instance.id: - yield node_instance - @property def host_ip(self): """ @@ -48,9 +33,8 @@ def host_ip(self): :return: """ assert isinstance(self._op_context, operation.NodeOperationContext) - host_id = self._op_context._actor.host_id - host_instance = self._op_context.model.node_instance.get(host_id) - return host_instance.runtime_properties.get('ip') + host = self._op_context.node_instance.host + return host.runtime_properties.get('ip') class RelationshipToolBelt(object): diff --git a/aria/orchestrator/context/workflow.py b/aria/orchestrator/context/workflow.py index 3dc222b4..e2e8e252 100644 --- a/aria/orchestrator/context/workflow.py +++ b/aria/orchestrator/context/workflow.py @@ -19,8 +19,7 @@ import threading from contextlib import contextmanager - -from aria import storage +from datetime import datetime from .exceptions import ContextException from .common import BaseContext @@ -30,53 +29,73 @@ class WorkflowContext(BaseContext): """ Context object used during workflow creation and execution """ - def __init__(self, parameters=None, *args, **kwargs): + def __init__(self, parameters=None, execution_id=None, *args, **kwargs): super(WorkflowContext, self).__init__(*args, **kwargs) self.parameters = parameters or {} # TODO: execution creation should happen somewhere else # should be moved there, when such logical place exists - try: - self.model.execution.get(self._execution_id) - except storage.exceptions.StorageError: - self._create_execution() + self._execution_id = self._create_execution() if execution_id is None else execution_id def __repr__(self): return ( '{name}(deployment_id={self._deployment_id}, ' - 'workflow_id={self._workflow_id}, ' - 'execution_id={self._execution_id})'.format( + 'workflow_name={self._workflow_name}'.format( name=self.__class__.__name__, self=self)) def _create_execution(self): execution_cls = self.model.execution.model_cls + now = datetime.utcnow() execution = self.model.execution.model_cls( - id=self._execution_id, - deployment_id=self.deployment.id, - workflow_id=self._workflow_id, blueprint_id=self.blueprint.id, + deployment_id=self.deployment.id, + workflow_name=self._workflow_name, + created_at=now, status=execution_cls.PENDING, parameters=self.parameters, ) - self.model.execution.store(execution) + self.model.execution.put(execution) + return execution.id + + @property + def execution(self): + """ + The execution model + """ + return self.model.execution.get(self._execution_id) + + @execution.setter + def execution(self, value): + """ + Store the execution in the model storage + """ + self.model.execution.put(value) @property def nodes(self): """ Iterator over nodes """ - return self.model.node.iter(filters={'blueprint_id': self.blueprint.id}) + return self.model.node.iter( + filters={ + 'deployment_id': self.deployment.id + } + ) @property def node_instances(self): """ Iterator over node instances """ - return self.model.node_instance.iter(filters={'deployment_id': self.deployment.id}) + return self.model.node_instance.iter( + filters={ + 'deployment_id': self.deployment.id + } + ) class _CurrentContext(threading.local): """ - Provides thread-level context, which sugarcoats the task api. + Provides thread-level context, which sugarcoats the task mapi. """ def __init__(self): diff --git a/aria/orchestrator/exceptions.py b/aria/orchestrator/exceptions.py index 75b37cf9..1a481945 100644 --- a/aria/orchestrator/exceptions.py +++ b/aria/orchestrator/exceptions.py @@ -12,9 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Orchestrator based exceptions +""" from aria.exceptions import AriaError class OrchestratorError(AriaError): + """ + Orchestrator based exception + """ pass diff --git a/aria/orchestrator/workflows/api/task.py b/aria/orchestrator/workflows/api/task.py index 4d36725b..1c12407e 100644 --- a/aria/orchestrator/workflows/api/task.py +++ b/aria/orchestrator/workflows/api/task.py @@ -18,7 +18,7 @@ """ from uuid import uuid4 -import aria +from aria.storage import models from ... import context from .. import exceptions @@ -75,8 +75,8 @@ def __init__(self, :param actor: the operation host on which this operation is registered. :param inputs: operation inputs. """ - assert isinstance(actor, (aria.storage.models.NodeInstance, - aria.storage.models.RelationshipInstance)) + assert isinstance(actor, (models.NodeInstance, + models.RelationshipInstance)) super(OperationTask, self).__init__() self.actor = actor self.name = '{name}.{actor.id}'.format(name=name, actor=actor) @@ -97,7 +97,7 @@ def node_instance(cls, instance, name, inputs=None, *args, **kwargs): :param instance: the node of which this operation belongs to. :param name: the name of the operation. """ - assert isinstance(instance, aria.storage.models.NodeInstance) + assert isinstance(instance, models.NodeInstance) operation_details = instance.node.operations[name] operation_inputs = operation_details.get('inputs', {}) operation_inputs.update(inputs or {}) @@ -119,7 +119,7 @@ def relationship_instance(cls, instance, name, operation_end, inputs=None, *args with 'source_operations' and 'target_operations' :param inputs any additional inputs to the operation """ - assert isinstance(instance, aria.storage.models.RelationshipInstance) + assert isinstance(instance, models.RelationshipInstance) if operation_end not in [cls.TARGET_OPERATION, cls.SOURCE_OPERATION]: raise exceptions.TaskException('The operation end should be {0} or {1}'.format( cls.TARGET_OPERATION, cls.SOURCE_OPERATION diff --git a/aria/orchestrator/workflows/builtin/heal.py b/aria/orchestrator/workflows/builtin/heal.py index dbfc14e9..de070956 100644 --- a/aria/orchestrator/workflows/builtin/heal.py +++ b/aria/orchestrator/workflows/builtin/heal.py @@ -84,16 +84,19 @@ def heal_uninstall(ctx, graph, failing_node_instances, targeted_node_instances): # create dependencies between the node instance sub workflow for node_instance in failing_node_instances: node_instance_sub_workflow = node_instance_sub_workflows[node_instance.id] - for relationship_instance in reversed(node_instance.relationship_instances): - graph.add_dependency(node_instance_sub_workflows[relationship_instance.target_id], - node_instance_sub_workflow) + for relationship_instance in reversed(node_instance.outbound_relationship_instances): + graph.add_dependency( + node_instance_sub_workflows[relationship_instance.target_node_instance.id], + node_instance_sub_workflow) # Add operations for intact nodes depending on a node instance belonging to node_instances for node_instance in targeted_node_instances: node_instance_sub_workflow = node_instance_sub_workflows[node_instance.id] - for relationship_instance in reversed(node_instance.relationship_instances): - target_node_instance = ctx.model.node_instance.get(relationship_instance.target_id) + for relationship_instance in reversed(node_instance.outbound_relationship_instances): + + target_node_instance = \ + ctx.model.node_instance.get(relationship_instance.target_node_instance.id) target_node_instance_subgraph = node_instance_sub_workflows[target_node_instance.id] graph.add_dependency(target_node_instance_subgraph, node_instance_sub_workflow) @@ -134,9 +137,10 @@ def heal_install(ctx, graph, failing_node_instances, targeted_node_instances): # create dependencies between the node instance sub workflow for node_instance in failing_node_instances: node_instance_sub_workflow = node_instance_sub_workflows[node_instance.id] - if node_instance.relationship_instances: - dependencies = [node_instance_sub_workflows[relationship_instance.target_id] - for relationship_instance in node_instance.relationship_instances] + if node_instance.outbound_relationship_instances: + dependencies = \ + [node_instance_sub_workflows[relationship_instance.target_node_instance.id] + for relationship_instance in node_instance.outbound_relationship_instances] graph.add_dependency(node_instance_sub_workflow, dependencies) # Add operations for intact nodes depending on a node instance @@ -144,8 +148,9 @@ def heal_install(ctx, graph, failing_node_instances, targeted_node_instances): for node_instance in targeted_node_instances: node_instance_sub_workflow = node_instance_sub_workflows[node_instance.id] - for relationship_instance in node_instance.relationship_instances: - target_node_instance = ctx.model.node_instance.get(relationship_instance.target_id) + for relationship_instance in node_instance.outbound_relationship_instances: + target_node_instance = ctx.model.node_instance.get( + relationship_instance.target_node_instance.id) target_node_instance_subworkflow = node_instance_sub_workflows[target_node_instance.id] graph.add_dependency(node_instance_sub_workflow, target_node_instance_subworkflow) diff --git a/aria/orchestrator/workflows/builtin/install.py b/aria/orchestrator/workflows/builtin/install.py index 0ab3ad61..eb5b4e8e 100644 --- a/aria/orchestrator/workflows/builtin/install.py +++ b/aria/orchestrator/workflows/builtin/install.py @@ -47,7 +47,8 @@ def install(ctx, graph, node_instances=(), node_instance_sub_workflows=None): # create dependencies between the node instance sub workflow for node_instance in node_instances: node_instance_sub_workflow = node_instance_sub_workflows[node_instance.id] - if node_instance.relationship_instances: - dependencies = [node_instance_sub_workflows[relationship_instance.target_id] - for relationship_instance in node_instance.relationship_instances] + if node_instance.outbound_relationship_instances: + dependencies = [ + node_instance_sub_workflows[relationship_instance.target_node_instance.id] + for relationship_instance in node_instance.outbound_relationship_instances] graph.add_dependency(node_instance_sub_workflow, dependencies) diff --git a/aria/orchestrator/workflows/builtin/uninstall.py b/aria/orchestrator/workflows/builtin/uninstall.py index f4e965c4..db1c0ccb 100644 --- a/aria/orchestrator/workflows/builtin/uninstall.py +++ b/aria/orchestrator/workflows/builtin/uninstall.py @@ -27,7 +27,7 @@ def uninstall(ctx, graph, node_instances=(), node_instance_sub_workflows=None): """ The uninstall workflow - :param WorkflowContext context: the workflow context + :param WorkflowContext ctx: the workflow context :param TaskGraph graph: the graph which will describe the workflow. :param node_instances: the node instances on which to run the workflow :param dict node_instance_sub_workflows: a dictionary of subworkflows with id as key and @@ -47,6 +47,7 @@ def uninstall(ctx, graph, node_instances=(), node_instance_sub_workflows=None): # create dependencies between the node instance sub workflow for node_instance in node_instances: node_instance_sub_workflow = node_instance_sub_workflows[node_instance.id] - for relationship_instance in reversed(node_instance.relationship_instances): - graph.add_dependency(node_instance_sub_workflows[relationship_instance.target_id], + for relationship_instance in reversed(node_instance.outbound_relationship_instances): + target_id = relationship_instance.target_node_instance.id + graph.add_dependency(node_instance_sub_workflows[target_id], node_instance_sub_workflow) diff --git a/aria/orchestrator/workflows/builtin/workflows.py b/aria/orchestrator/workflows/builtin/workflows.py index 0eb8c346..4f765b37 100644 --- a/aria/orchestrator/workflows/builtin/workflows.py +++ b/aria/orchestrator/workflows/builtin/workflows.py @@ -37,7 +37,6 @@ def install_node_instance(graph, node_instance, **kwargs): """ A workflow which installs a node instance. - :param WorkflowContext ctx: the workflow context :param TaskGraph graph: the tasks graph of which to edit :param node_instance: the node instance to install :return: @@ -68,7 +67,6 @@ def install_node_instance(graph, node_instance, **kwargs): def preconfigure_relationship(graph, node_instance, **kwargs): """ - :param context: :param graph: :param node_instance: :return: @@ -82,7 +80,6 @@ def preconfigure_relationship(graph, node_instance, **kwargs): def postconfigure_relationship(graph, node_instance, **kwargs): """ - :param context: :param graph: :param node_instance: :return: @@ -96,7 +93,6 @@ def postconfigure_relationship(graph, node_instance, **kwargs): def establish_relationship(graph, node_instance, **kwargs): """ - :param context: :param graph: :param node_instance: :return: @@ -113,7 +109,6 @@ def establish_relationship(graph, node_instance, **kwargs): def uninstall_node_instance(graph, node_instance, **kwargs): """ A workflow which uninstalls a node instance. - :param WorkflowContext context: the workflow context :param TaskGraph graph: the tasks graph of which to edit :param node_instance: the node instance to uninstall :return: @@ -135,7 +130,6 @@ def uninstall_node_instance(graph, node_instance, **kwargs): def unlink_relationship(graph, node_instance): """ - :param context: :param graph: :param node_instance: :return: @@ -179,8 +173,8 @@ def relationships_tasks(graph, operation_name, node_instance): :return: """ relationships_groups = groupby( - node_instance.relationship_instances, - key=lambda relationship_instance: relationship_instance.relationship.target_id) + node_instance.outbound_relationship_instances, + key=lambda relationship_instance: relationship_instance.target_node_instance.id) sub_tasks = [] for _, (_, relationship_group) in enumerate(relationships_groups): @@ -196,11 +190,8 @@ def relationships_tasks(graph, operation_name, node_instance): def relationship_tasks(relationship_instance, operation_name): """ Creates a relationship task source and target. - :param NodeInstance node_instance: the node instance of the relationship :param RelationshipInstance relationship_instance: the relationship instance itself - :param WorkflowContext context: :param operation_name: - :param index: the relationship index - enables pretty print :return: """ source_operation = task.OperationTask.relationship_instance( diff --git a/aria/orchestrator/workflows/core/engine.py b/aria/orchestrator/workflows/core/engine.py index 87ea8c69..2d26aebd 100644 --- a/aria/orchestrator/workflows/core/engine.py +++ b/aria/orchestrator/workflows/core/engine.py @@ -100,7 +100,11 @@ def _all_tasks_consumed(self): return len(self._execution_graph.node) == 0 def _tasks_iter(self): - return (data['task'] for _, data in self._execution_graph.nodes_iter(data=True)) + for _, data in self._execution_graph.nodes_iter(data=True): + task = data['task'] + if isinstance(task, engine_task.OperationTask): + self._workflow_context.model.task.refresh(task.model_task) + yield task def _handle_executable_task(self, task): if isinstance(task, engine_task.StubTask): diff --git a/aria/orchestrator/workflows/core/task.py b/aria/orchestrator/workflows/core/task.py index a583cfce..0be17fe7 100644 --- a/aria/orchestrator/workflows/core/task.py +++ b/aria/orchestrator/workflows/core/task.py @@ -106,32 +106,34 @@ class OperationTask(BaseTask): def __init__(self, api_task, *args, **kwargs): super(OperationTask, self).__init__(id=api_task.id, **kwargs) self._workflow_context = api_task._workflow_context - task_model = api_task._workflow_context.model.task.model_cls - operation_task = task_model( - id=api_task.id, - name=api_task.name, - operation_mapping=api_task.operation_mapping, - actor=api_task.actor, - inputs=api_task.inputs, - status=task_model.PENDING, - execution_id=self._workflow_context._execution_id, - max_attempts=api_task.max_attempts, - retry_interval=api_task.retry_interval, - ignore_failure=api_task.ignore_failure - ) + base_task_model = api_task._workflow_context.model.task.model_cls if isinstance(api_task.actor, models.NodeInstance): context_class = operation_context.NodeOperationContext + task_model_cls = base_task_model.as_node_instance elif isinstance(api_task.actor, models.RelationshipInstance): context_class = operation_context.RelationshipOperationContext + task_model_cls = base_task_model.as_relationship_instance else: - raise RuntimeError('No operation context could be created for {0}' - .format(api_task.actor.model_cls)) + raise RuntimeError('No operation context could be created for {actor.model_cls}' + .format(actor=api_task.actor)) + + operation_task = task_model_cls( + name=api_task.name, + operation_mapping=api_task.operation_mapping, + instance_id=api_task.actor.id, + inputs=api_task.inputs, + status=base_task_model.PENDING, + max_attempts=api_task.max_attempts, + retry_interval=api_task.retry_interval, + ignore_failure=api_task.ignore_failure, + ) + self._workflow_context.model.task.put(operation_task) self._ctx = context_class(name=api_task.name, workflow_context=self._workflow_context, - task=operation_task) - self._workflow_context.model.task.store(operation_task) + task=operation_task, + actor=operation_task.actor) self._task_id = operation_task.id self._update_fields = None @@ -161,7 +163,7 @@ def model_task(self): @model_task.setter def model_task(self, value): - self._workflow_context.model.task.store(value) + self._workflow_context.model.task.put(value) @property def context(self): diff --git a/aria/storage/__init__.py b/aria/storage/__init__.py index 2d142a5d..fd69d47d 100644 --- a/aria/storage/__init__.py +++ b/aria/storage/__init__.py @@ -20,14 +20,14 @@ Storage package is a generic abstraction over different storage types. We define this abstraction with the following components: -1. storage: simple api to use -2. driver: implementation of the database client api. +1. storage: simple mapi to use +2. driver: implementation of the database client mapi. 3. model: defines the structure of the table/document. 4. field: defines a field/item in the model. API: * application_storage_factory - function, default Aria storage factory. - * Storage - class, simple storage api. + * Storage - class, simple storage mapi. * models - module, default Aria standard models. * structures - module, default Aria structures - holds the base model, and different fields types. @@ -37,354 +37,28 @@ * drivers - module, a pool of Aria standard drivers. * StorageDriver - class, abstract model implementation. """ -# todo: rewrite the above package documentation -# (something like explaning the two types of storage - models and resources) - -from collections import namedtuple - -from .structures import Storage, Field, Model, IterField, PointerField -from .drivers import ( - ModelDriver, - ResourceDriver, - FileSystemResourceDriver, - FileSystemModelDriver, +from .core import ( + Storage, + ModelStorage, + ResourceStorage, +) +from . import ( + exceptions, + api, + structures, + core, + filesystem_rapi, + sql_mapi, + models ) -from . import models, exceptions __all__ = ( - 'ModelStorage', - 'ResourceStorage', - 'FileSystemModelDriver', - 'models', + 'exceptions', 'structures', - 'Field', - 'IterField', - 'PointerField', - 'Model', - 'drivers', - 'ModelDriver', - 'ResourceDriver', - 'FileSystemResourceDriver', + # 'Storage', + # 'ModelStorage', + # 'ResourceStorage', + 'filesystem_rapi', + 'sql_mapi', + 'api' ) -# todo: think about package output api's... -# todo: in all drivers name => entry_type -# todo: change in documentation str => basestring - - -class ModelStorage(Storage): - """ - Managing the models storage. - """ - def __init__(self, driver, model_classes=(), **kwargs): - """ - Simple storage client api for Aria applications. - The storage instance defines the tables/documents/code api. - - :param ModelDriver driver: model storage driver. - :param model_classes: the models to register. - """ - assert isinstance(driver, ModelDriver) - super(ModelStorage, self).__init__(driver, model_classes, **kwargs) - - def __getattr__(self, table): - """ - getattr is a shortcut to simple api - - for Example: - >> storage = ModelStorage(driver=FileSystemModelDriver('/tmp')) - >> node_table = storage.node - >> for node in node_table: - >> print node - - :param str table: table name to get - :return: a storage object that mapped to the table name - """ - return super(ModelStorage, self).__getattr__(table) - - def register(self, model_cls): - """ - Registers the model type in the resource storage manager. - :param model_cls: the model to register. - """ - model_name = generate_lower_name(model_cls) - model_api = _ModelApi(model_name, self.driver, model_cls) - self.registered[model_name] = model_api - - for pointer_schema_register in model_api.pointer_mapping.values(): - model_cls = pointer_schema_register.model_cls - self.register(model_cls) - -_Pointer = namedtuple('_Pointer', 'name, is_iter') - - -class _ModelApi(object): - def __init__(self, name, driver, model_cls): - """ - Managing the model in the storage, using the driver. - - :param basestring name: the name of the model. - :param ModelDriver driver: the driver which supports this model in the storage. - :param Model model_cls: table/document class model. - """ - assert isinstance(driver, ModelDriver) - assert issubclass(model_cls, Model) - self.name = name - self.driver = driver - self.model_cls = model_cls - self.pointer_mapping = {} - self._setup_pointers_mapping() - - def _setup_pointers_mapping(self): - for field_name, field_cls in vars(self.model_cls).items(): - if not(isinstance(field_cls, PointerField) and field_cls.type): - continue - pointer_key = _Pointer(field_name, is_iter=isinstance(field_cls, IterField)) - self.pointer_mapping[pointer_key] = self.__class__( - name=generate_lower_name(field_cls.type), - driver=self.driver, - model_cls=field_cls.type) - - def __iter__(self): - return self.iter() - - def __repr__(self): - return '{self.name}(driver={self.driver}, model={self.model_cls})'.format(self=self) - - def create(self): - """ - Creates the model in the storage. - """ - with self.driver as connection: - connection.create(self.name) - - def get(self, entry_id, **kwargs): - """ - Getter for the model from the storage. - - :param basestring entry_id: the id of the table/document. - :return: model instance - :rtype: Model - """ - with self.driver as connection: - data = connection.get( - name=self.name, - entry_id=entry_id, - **kwargs) - data.update(self._get_pointers(data, **kwargs)) - return self.model_cls(**data) - - def store(self, entry, **kwargs): - """ - Setter for the model in the storage. - - :param Model entry: the table/document to store. - """ - assert isinstance(entry, self.model_cls) - with self.driver as connection: - data = entry.fields_dict - data.update(self._store_pointers(data, **kwargs)) - connection.store( - name=self.name, - entry_id=entry.id, - entry=data, - **kwargs) - - def delete(self, entry_id, **kwargs): - """ - Delete the model from storage. - - :param basestring entry_id: id of the entity to delete from storage. - """ - entry = self.get(entry_id) - with self.driver as connection: - self._delete_pointers(entry, **kwargs) - connection.delete( - name=self.name, - entry_id=entry_id, - **kwargs) - - def iter(self, **kwargs): - """ - Generator over the entries of model in storage. - """ - with self.driver as connection: - for data in connection.iter(name=self.name, **kwargs): - data.update(self._get_pointers(data, **kwargs)) - yield self.model_cls(**data) - - def update(self, entry_id, **kwargs): - """ - Updates and entry in storage. - - :param str entry_id: the id of the table/document. - :param kwargs: the fields to update. - :return: - """ - with self.driver as connection: - connection.update( - name=self.name, - entry_id=entry_id, - **kwargs - ) - - def _get_pointers(self, data, **kwargs): - pointers = {} - for field, schema in self.pointer_mapping.items(): - if field.is_iter: - pointers[field.name] = [ - schema.get(entry_id=pointer_id, **kwargs) - for pointer_id in data[field.name] - if pointer_id] - elif data[field.name]: - pointers[field.name] = schema.get(entry_id=data[field.name], **kwargs) - return pointers - - def _store_pointers(self, data, **kwargs): - pointers = {} - for field, model_api in self.pointer_mapping.items(): - if field.is_iter: - pointers[field.name] = [] - for iter_entity in data[field.name]: - pointers[field.name].append(iter_entity.id) - model_api.store(iter_entity, **kwargs) - else: - pointers[field.name] = data[field.name].id - model_api.store(data[field.name], **kwargs) - return pointers - - def _delete_pointers(self, entry, **kwargs): - for field, schema in self.pointer_mapping.items(): - if field.is_iter: - for iter_entry in getattr(entry, field.name): - schema.delete(iter_entry.id, **kwargs) - else: - schema.delete(getattr(entry, field.name).id, **kwargs) - - -class ResourceApi(object): - """ - Managing the resource in the storage, using the driver. - - :param basestring name: the name of the resource. - :param ResourceDriver driver: the driver which supports this resource in the storage. - """ - def __init__(self, driver, resource_name): - """ - Managing the resources in the storage, using the driver. - - :param ResourceDriver driver: the driver which supports this model in the storage. - :param basestring resource_name: the type of the entry this resourceAPI manages. - """ - assert isinstance(driver, ResourceDriver) - self.driver = driver - self.resource_name = resource_name - - def __repr__(self): - return '{name}(driver={self.driver}, resource={self.resource_name})'.format( - name=self.__class__.__name__, self=self) - - def create(self): - """ - Create the resource dir in the storage. - """ - with self.driver as connection: - connection.create(self.resource_name) - - def data(self, entry_id, path=None, **kwargs): - """ - Retrieve the content of a storage resource. - - :param basestring entry_id: the id of the entry. - :param basestring path: path of the resource on the storage. - :param kwargs: resources to be passed to the driver.. - :return the content of a single file: - """ - with self.driver as connection: - return connection.data( - entry_type=self.resource_name, - entry_id=entry_id, - path=path, - **kwargs) - - def download(self, entry_id, destination, path=None, **kwargs): - """ - Download a file/dir from the resource storage. - - :param basestring entry_id: the id of the entry. - :param basestring destination: the destination of the file/dir. - :param basestring path: path of the resource on the storage. - """ - with self.driver as connection: - connection.download( - entry_type=self.resource_name, - entry_id=entry_id, - destination=destination, - path=path, - **kwargs) - - def upload(self, entry_id, source, path=None, **kwargs): - """ - Upload a file/dir from the resource storage. - - :param basestring entry_id: the id of the entry. - :param basestring source: the source path of the file to upload. - :param basestring path: the destination of the file, relative to the root dir - of the resource - """ - with self.driver as connection: - connection.upload( - entry_type=self.resource_name, - entry_id=entry_id, - source=source, - path=path, - **kwargs) - - -def generate_lower_name(model_cls): - """ - Generates the name of the class from the class object. e.g. SomeClass -> some_class - :param model_cls: the class to evaluate. - :return: lower name - :rtype: basestring - """ - return ''.join( - character if character.islower() else '_{0}'.format(character.lower()) - for character in model_cls.__name__)[1:] - - -class ResourceStorage(Storage): - """ - Managing the resource storage. - """ - def __init__(self, driver, resources=(), **kwargs): - """ - Simple storage client api for Aria applications. - The storage instance defines the tables/documents/code api. - - :param ResourceDriver driver: resource storage driver - :param resources: the resources to register. - """ - assert isinstance(driver, ResourceDriver) - super(ResourceStorage, self).__init__(driver, resources, **kwargs) - - def register(self, resource): - """ - Registers the resource type in the resource storage manager. - :param resource: the resource to register. - """ - self.registered[resource] = ResourceApi(self.driver, resource_name=resource) - - def __getattr__(self, resource): - """ - getattr is a shortcut to simple api - - for Example: - >> storage = ResourceStorage(driver=FileSystemResourceDriver('/tmp')) - >> blueprint_resources = storage.blueprint - >> blueprint_resources.download(blueprint_id, destination='~/blueprint/') - - :param str resource: resource name to download - :return: a storage object that mapped to the resource name - :rtype: ResourceApi - """ - return super(ResourceStorage, self).__getattr__(resource) diff --git a/aria/storage/api.py b/aria/storage/api.py new file mode 100644 index 00000000..d6fc3b8d --- /dev/null +++ b/aria/storage/api.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +General storage API +""" + + +class StorageAPI(object): + """ + General storage Base API + """ + def create(self, **kwargs): + """ + Create a storage API. + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract create method') + + +class ModelAPI(StorageAPI): + """ + A Base object for the model. + """ + def __init__(self, model_cls, name=None, **kwargs): + """ + Base model API + + :param model_cls: the representing class of the model + :param str name: the name of the model + :param kwargs: + """ + super(ModelAPI, self).__init__(**kwargs) + self._model_cls = model_cls + self._name = name or generate_lower_name(model_cls) + + @property + def name(self): + """ + The name of the class + :return: name of the class + """ + return self._name + + @property + def model_cls(self): + """ + The class represting the model + :return: + """ + return self._model_cls + + def get(self, entry_id, filters=None, **kwargs): + """ + Get entry from storage. + + :param entry_id: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract get method') + + def put(self, entry, **kwargs): + """ + Store entry in storage + + :param entry: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract store method') + + def delete(self, entry_id, **kwargs): + """ + Delete entry from storage. + + :param entry_id: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract delete method') + + def __iter__(self): + return self.iter() + + def iter(self, **kwargs): + """ + Iter over the entries in storage. + + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract iter method') + + def update(self, entry, **kwargs): + """ + Update entry in storage. + + :param entry: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract update method') + + +class ResourceAPI(StorageAPI): + """ + A Base object for the resource. + """ + def __init__(self, name): + """ + Base resource API + :param str name: the resource type + """ + self._name = name + + @property + def name(self): + """ + The name of the resource + :return: + """ + return self._name + + def read(self, entry_id, path=None, **kwargs): + """ + Get a bytesteam from the storage. + + :param entry_id: + :param path: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract data method') + + def download(self, entry_id, destination, path=None, **kwargs): + """ + Download a resource from the storage. + + :param entry_id: + :param destination: + :param path: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract download method') + + def upload(self, entry_id, source, path=None, **kwargs): + """ + Upload a resource to the storage. + + :param entry_id: + :param source: + :param path: + :param kwargs: + :return: + """ + raise NotImplementedError('Subclass must implement abstract upload method') + + +def generate_lower_name(model_cls): + """ + Generates the name of the class from the class object. e.g. SomeClass -> some_class + :param model_cls: the class to evaluate. + :return: lower name + :rtype: basestring + """ + return ''.join( + character if character.islower() else '_{0}'.format(character.lower()) + for character in model_cls.__name__)[1:] diff --git a/aria/storage/core.py b/aria/storage/core.py new file mode 100644 index 00000000..a5d3210d --- /dev/null +++ b/aria/storage/core.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Aria's storage Sub-Package +Path: aria.storage + +Storage package is a generic abstraction over different storage types. +We define this abstraction with the following components: + +1. storage: simple mapi to use +2. driver: implementation of the database client mapi. +3. model: defines the structure of the table/document. +4. field: defines a field/item in the model. + +API: + * application_storage_factory - function, default Aria storage factory. + * Storage - class, simple storage mapi. + * models - module, default Aria standard models. + * structures - module, default Aria structures - holds the base model, + and different fields types. + * Model - class, abstract model implementation. + * Field - class, base field implementation. + * IterField - class, base iterable field implementation. + * drivers - module, a pool of Aria standard drivers. + * StorageDriver - class, abstract model implementation. +""" + +from aria.logger import LoggerMixin +from . import api as storage_api + +__all__ = ( + 'Storage', + 'ModelStorage', + 'ResourceStorage' +) + + +class Storage(LoggerMixin): + """ + Represents the storage + """ + def __init__(self, api_cls, api_kwargs=None, items=(), **kwargs): + self._api_kwargs = api_kwargs or {} + super(Storage, self).__init__(**kwargs) + self.api = api_cls + self.registered = {} + for item in items: + self.register(item) + self.logger.debug('{name} object is ready: {0!r}'.format( + self, name=self.__class__.__name__)) + + def __repr__(self): + return '{name}(api={self.api})'.format(name=self.__class__.__name__, self=self) + + def __getattr__(self, item): + try: + return self.registered[item] + except KeyError: + return super(Storage, self).__getattribute__(item) + + def register(self, entry): + """ + Register the entry to the storage + :param name: + :return: + """ + raise NotImplementedError('Subclass must implement abstract register method') + + +class ResourceStorage(Storage): + """ + Represents resource storage. + """ + def register(self, name): + """ + Register the resource type to resource storage. + :param name: + :return: + """ + self.registered[name] = self.api(name=name, **self._api_kwargs) + self.registered[name].create() + self.logger.debug('setup {name} in storage {self!r}'.format(name=name, self=self)) + + +class ModelStorage(Storage): + """ + Represents model storage. + """ + def register(self, model_cls): + """ + Register the model into the model storage. + :param model_cls: the model to register. + :return: + """ + model_name = storage_api.generate_lower_name(model_cls) + if model_name in self.registered: + self.logger.debug('{name} in already storage {self!r}'.format(name=model_name, + self=self)) + return + self.registered[model_name] = self.api(name=model_name, + model_cls=model_cls, + **self._api_kwargs) + self.registered[model_name].create() + self.logger.debug('setup {name} in storage {self!r}'.format(name=model_name, self=self)) + + def drop(self): + """ + Drop all the tables from the model. + :return: + """ + for mapi in self.registered.values(): + mapi.drop() diff --git a/aria/storage/drivers.py b/aria/storage/drivers.py deleted file mode 100644 index 1f96956c..00000000 --- a/aria/storage/drivers.py +++ /dev/null @@ -1,416 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Aria's storage.drivers module -Path: aria.storage.driver - -drivers module holds a generic abstract implementation of drivers. - -classes: - * Driver - abstract storage driver implementation. - * ModelDriver - abstract model base storage driver. - * ResourceDriver - abstract resource base storage driver. - * FileSystemModelDriver - file system implementation for model storage driver. - * FileSystemResourceDriver - file system implementation for resource storage driver. -""" - -import distutils.dir_util # pylint: disable=no-name-in-module, import-error -import os -import shutil -from functools import partial -from multiprocessing import RLock - -import jsonpickle - -from ..logger import LoggerMixin -from .exceptions import StorageError - -__all__ = ( - 'ModelDriver', - 'FileSystemModelDriver', - 'ResourceDriver', - 'FileSystemResourceDriver', -) - - -class Driver(LoggerMixin): - """ - Driver: storage driver context manager - abstract driver implementation. - In the implementation level, It is a good practice to raise StorageError on Errors. - """ - - def __enter__(self): - """ - Context manager entry method, executes connect. - :return: context manager instance - :rtype: Driver - """ - self.connect() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Context manager exit method, executes disconnect. - """ - self.disconnect() - if not exc_type: - return - # self.logger.debug( - # '{name} had an error'.format(name=self.__class__.__name__), - # exc_info=(exc_type, exc_val, exc_tb)) - if StorageError in exc_type.mro(): - return - raise StorageError('Exception had occurred, {type}: {message}'.format( - type=exc_type, message=str(exc_val))) - - def connect(self): - """ - Open storage connection. - In some cases, This method can get the connection from a connection pool. - """ - pass - - def disconnect(self): - """ - Close storage connection. - In some cases, This method can release the connection to the connection pool. - """ - pass - - def create(self, name, *args, **kwargs): - """ - Create table/document in storage by name. - :param str name: name of table/document in storage. - """ - pass - - -class ModelDriver(Driver): - """ - ModelDriver context manager. - Base Driver for Model based storage. - """ - - def get(self, name, entry_id, **kwargs): - """ - Getter from storage. - :param str name: name of table/document in storage. - :param str entry_id: id of the document to get from storage. - :return: value of entity from the storage. - """ - raise NotImplementedError('Subclass must implement abstract get method') - - def delete(self, name, entry_id, **kwargs): - """ - Delete from storage. - :param str name: name of table/document in storage. - :param str entry_id: id of the entity to delete from storage. - :param dict kwargs: extra kwargs if needed. - """ - raise NotImplementedError('Subclass must implement abstract delete method') - - def store(self, name, entry_id, entry, **kwargs): - """ - Setter to storage. - :param str name: name of table/document in storage. - :param str entry_id: id of the entity to store in the storage. - :param dict entry: content to store. - """ - raise NotImplementedError('Subclass must implement abstract store method') - - def iter(self, name, **kwargs): - """ - Generator over the entries of table/document in storage. - :param str name: name of table/document/file in storage to iter over. - """ - raise NotImplementedError('Subclass must implement abstract iter method') - - def update(self, name, entry_id, **kwargs): - """ - Updates and entry in storage. - - :param str name: name of table/document in storage. - :param str entry_id: id of the document to get from storage. - :param kwargs: the fields to update. - :return: - """ - raise NotImplementedError('Subclass must implement abstract store method') - - -class ResourceDriver(Driver): - """ - ResourceDriver context manager. - Base Driver for Resource based storage. - - Resource storage structure is a file system base. - /// - entry: can be one single file or multiple files and directories. - """ - - def data(self, entry_type, entry_id, path=None, **kwargs): - """ - Get the binary data from a file in a resource entry. - If the entry is a single file no path needed, - If the entry contain number of files the path will gide to the relevant file. - - resource path: - /// - - :param basestring entry_type: resource name. - :param basestring entry_id: id of the entity to resource in the storage. - :param basestring path: path to resource relative to entry_id folder in the storage. - :return: entry file object. - :rtype: bytes - """ - raise NotImplementedError('Subclass must implement abstract get method') - - def download(self, entry_type, entry_id, destination, path=None, **kwargs): - """ - Download the resource to a destination. - Like data method bat this method isn't returning data, - Instead it create a new file in local file system. - - resource path: - /// - copy to: - / - destination can be file or directory - - :param basestring entry_type: resource name. - :param basestring entry_id: id of the entity to resource in the storage. - :param basestring destination: path in local file system to download to. - :param basestring path: path to resource relative to entry_id folder in the storage. - """ - raise NotImplementedError('Subclass must implement abstract get method') - - def upload(self, entry_type, entry_id, source, path=None, **kwargs): - """ - Upload the resource from source. - source can be file or directory with files. - - copy from: - / - to resource path: - /// - - :param basestring entry_type: resource name. - :param basestring entry_id: id of the entity to resource in the storage. - :param basestring source: source can be file or directory with files. - :param basestring path: path to resource relative to entry_id folder in the storage. - """ - raise NotImplementedError('Subclass must implement abstract get method') - - -class BaseFileSystemDriver(Driver): - """ - Base class which handles storage on the file system. - """ - def __init__(self, *args, **kwargs): - super(BaseFileSystemDriver, self).__init__(*args, **kwargs) - self._lock = RLock() - - def connect(self): - self._lock.acquire() - - def disconnect(self): - self._lock.release() - - def __getstate__(self): - obj_dict = super(BaseFileSystemDriver, self).__getstate__() - del obj_dict['_lock'] - return obj_dict - - def __setstate__(self, obj_dict): - super(BaseFileSystemDriver, self).__setstate__(obj_dict) - vars(self).update(_lock=RLock(), **obj_dict) - - -class FileSystemModelDriver(ModelDriver, BaseFileSystemDriver): - """ - FileSystemModelDriver context manager. - """ - - def __init__(self, directory, **kwargs): - """ - File system implementation for storage driver. - :param str directory: root dir for storage. - """ - super(FileSystemModelDriver, self).__init__(**kwargs) - self.directory = directory - - self._join_path = partial(os.path.join, self.directory) - - def __repr__(self): - return '{cls.__name__}(directory={self.directory})'.format( - cls=self.__class__, self=self) - - def create(self, name): - """ - Create directory in storage by path. - tries to create the root directory as well. - :param str name: path of file in storage. - """ - try: - os.makedirs(self.directory) - except (OSError, IOError): - pass - os.makedirs(self._join_path(name)) - - def get(self, name, entry_id, **kwargs): - """ - Getter from storage. - :param str name: name of directory in storage. - :param str entry_id: id of the file to get from storage. - :return: value of file from storage. - :rtype: dict - """ - with open(self._join_path(name, entry_id)) as file_obj: - return jsonpickle.loads(file_obj.read()) - - def store(self, name, entry_id, entry, **kwargs): - """ - Delete from storage. - :param str name: name of directory in storage. - :param str entry_id: id of the file to delete from storage. - """ - with open(self._join_path(name, entry_id), 'w') as file_obj: - file_obj.write(jsonpickle.dumps(entry)) - - def delete(self, name, entry_id, **kwargs): - """ - Delete from storage. - :param str name: name of directory in storage. - :param str entry_id: id of the file to delete from storage. - """ - os.remove(self._join_path(name, entry_id)) - - def iter(self, name, filters=None, **kwargs): - """ - Generator over the entries of directory in storage. - :param str name: name of directory in storage to iter over. - :param dict filters: filters for query - """ - filters = filters or {} - - for entry_id in os.listdir(self._join_path(name)): - value = self.get(name, entry_id=entry_id) - for filter_name, filter_value in filters.items(): - if value.get(filter_name) != filter_value: - break - else: - yield value - - def update(self, name, entry_id, **kwargs): - """ - Updates and entry in storage. - - :param str name: name of table/document in storage. - :param str entry_id: id of the document to get from storage. - :param kwargs: the fields to update. - :return: - """ - entry_dict = self.get(name, entry_id) - entry_dict.update(**kwargs) - self.store(name, entry_id, entry_dict) - - -class FileSystemResourceDriver(ResourceDriver, BaseFileSystemDriver): - """ - FileSystemResourceDriver context manager. - """ - - def __init__(self, directory, **kwargs): - """ - File system implementation for storage driver. - :param str directory: root dir for storage. - """ - super(FileSystemResourceDriver, self).__init__(**kwargs) - self.directory = directory - self._join_path = partial(os.path.join, self.directory) - - def __repr__(self): - return '{cls.__name__}(directory={self.directory})'.format( - cls=self.__class__, self=self) - - def create(self, name): - """ - Create directory in storage by path. - tries to create the root directory as well. - :param basestring name: path of file in storage. - """ - try: - os.makedirs(self.directory) - except (OSError, IOError): - pass - os.makedirs(self._join_path(name)) - - def data(self, entry_type, entry_id, path=None): - """ - Retrieve the content of a file system storage resource. - - :param basestring entry_type: the type of the entry. - :param basestring entry_id: the id of the entry. - :param basestring path: a path to a specific resource. - :return: the content of the file - :rtype: bytes - """ - resource_relative_path = os.path.join(entry_type, entry_id, path or '') - resource = os.path.join(self.directory, resource_relative_path) - if not os.path.exists(resource): - raise StorageError("Resource {0} does not exist".format(resource_relative_path)) - if not os.path.isfile(resource): - resources = os.listdir(resource) - if len(resources) != 1: - raise StorageError('No resource in path: {0}'.format(resource)) - resource = os.path.join(resource, resources[0]) - with open(resource, 'rb') as resource_file: - return resource_file.read() - - def download(self, entry_type, entry_id, destination, path=None): - """ - Download a specific file or dir from the file system resource storage. - - :param basestring entry_type: the name of the entry. - :param basestring entry_id: the id of the entry - :param basestring destination: the destination of the files. - :param basestring path: a path on the remote machine relative to the root of the entry. - """ - resource_relative_path = os.path.join(entry_type, entry_id, path or '') - resource = os.path.join(self.directory, resource_relative_path) - if not os.path.exists(resource): - raise StorageError("Resource {0} does not exist".format(resource_relative_path)) - if os.path.isfile(resource): - shutil.copy2(resource, destination) - else: - distutils.dir_util.copy_tree(resource, destination) # pylint: disable=no-member - - def upload(self, entry_type, entry_id, source, path=None): - """ - Uploads a specific file or dir to the file system resource storage. - - :param basestring entry_type: the name of the entry. - :param basestring entry_id: the id of the entry - :param source: the source of the files to upload. - :param path: the destination of the file/s relative to the entry root dir. - """ - resource_directory = os.path.join(self.directory, entry_type, entry_id) - if not os.path.exists(resource_directory): - os.makedirs(resource_directory) - destination = os.path.join(resource_directory, path or '') - if os.path.isfile(source): - shutil.copy2(source, destination) - else: - distutils.dir_util.copy_tree(source, destination) # pylint: disable=no-member diff --git a/aria/storage/exceptions.py b/aria/storage/exceptions.py index 22dfc507..f982f63b 100644 --- a/aria/storage/exceptions.py +++ b/aria/storage/exceptions.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Storage based exceptions +""" from .. import exceptions diff --git a/aria/storage/filesystem_rapi.py b/aria/storage/filesystem_rapi.py new file mode 100644 index 00000000..f810f581 --- /dev/null +++ b/aria/storage/filesystem_rapi.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SQLalchemy based RAPI +""" +import os +import shutil +from contextlib import contextmanager +from functools import partial +from distutils import dir_util # https://github.com/PyCQA/pylint/issues/73; pylint: disable=no-name-in-module +from multiprocessing import RLock + +from aria.storage import ( + api, + exceptions +) + + +class FileSystemResourceAPI(api.ResourceAPI): + """ + File system resource storage. + """ + + def __init__(self, directory, **kwargs): + """ + File system implementation for storage api. + :param str directory: root dir for storage. + """ + super(FileSystemResourceAPI, self).__init__(**kwargs) + self.directory = directory + self.base_path = os.path.join(self.directory, self.name) + self._join_path = partial(os.path.join, self.base_path) + self._lock = RLock() + + @contextmanager + def connect(self): + """ + Established a connection and destroys it after use. + :return: + """ + try: + self._establish_connection() + yield self + except BaseException as e: + raise exceptions.StorageError(str(e)) + finally: + self._destroy_connection() + + def _establish_connection(self): + """ + Establish a conenction. used in the 'connect' contextmanager. + :return: + """ + self._lock.acquire() + + + def _destroy_connection(self): + """ + Destroy a connection. used in the 'connect' contextmanager. + :return: + """ + self._lock.release() + + def __repr__(self): + return '{cls.__name__}(directory={self.directory})'.format( + cls=self.__class__, self=self) + + def create(self, **kwargs): + """ + Create directory in storage by path. + tries to create the root directory as well. + :param str name: path of file in storage. + """ + try: + os.makedirs(self.directory) + except (OSError, IOError): + pass + os.makedirs(self.base_path) + + def read(self, entry_id, path=None, **_): + """ + Retrieve the content of a file system storage resource. + + :param str entry_type: the type of the entry. + :param str entry_id: the id of the entry. + :param str path: a path to a specific resource. + :return: the content of the file + :rtype: bytes + """ + resource_relative_path = os.path.join(self.name, entry_id, path or '') + resource = os.path.join(self.directory, resource_relative_path) + if not os.path.exists(resource): + raise exceptions.StorageError("Resource {0} does not exist". + format(resource_relative_path)) + if not os.path.isfile(resource): + resources = os.listdir(resource) + if len(resources) != 1: + raise exceptions.StorageError('No resource in path: {0}'.format(resource)) + resource = os.path.join(resource, resources[0]) + with open(resource, 'rb') as resource_file: + return resource_file.read() + + def download(self, entry_id, destination, path=None, **_): + """ + Download a specific file or dir from the file system resource storage. + + :param str entry_type: the name of the entry. + :param str entry_id: the id of the entry + :param str destination: the destination of the files. + :param str path: a path on the remote machine relative to the root of the entry. + """ + resource_relative_path = os.path.join(self.name, entry_id, path or '') + resource = os.path.join(self.directory, resource_relative_path) + if not os.path.exists(resource): + raise exceptions.StorageError("Resource {0} does not exist". + format(resource_relative_path)) + if os.path.isfile(resource): + shutil.copy2(resource, destination) + else: + dir_util.copy_tree(resource, destination) # pylint: disable=no-member + + def upload(self, entry_id, source, path=None, **_): + """ + Uploads a specific file or dir to the file system resource storage. + + :param str entry_type: the name of the entry. + :param str entry_id: the id of the entry + :param source: the source of the files to upload. + :param path: the destination of the file/s relative to the entry root dir. + """ + resource_directory = os.path.join(self.directory, self.name, entry_id) + if not os.path.exists(resource_directory): + os.makedirs(resource_directory) + destination = os.path.join(resource_directory, path or '') + if os.path.isfile(source): + shutil.copy2(source, destination) + else: + dir_util.copy_tree(source, destination) # pylint: disable=no-member diff --git a/aria/storage/models.py b/aria/storage/models.py index d24ad753..6302e66e 100644 --- a/aria/storage/models.py +++ b/aria/storage/models.py @@ -36,16 +36,30 @@ * ProviderContext - provider context implementation model. * Plugin - plugin implementation model. """ - +from collections import namedtuple from datetime import datetime -from types import NoneType -from .structures import Field, IterPointerField, Model, uuid_generator, PointerField +from sqlalchemy.ext.declarative.base import declared_attr + +from .structures import ( + SQLModelBase, + Column, + Integer, + Text, + DateTime, + Boolean, + Enum, + String, + Float, + List, + Dict, + foreign_key, + one_to_many_relationship, + relationship_to_self, + orm) __all__ = ( - 'Model', 'Blueprint', - 'Snapshot', 'Deployment', 'DeploymentUpdateStep', 'DeploymentUpdate', @@ -59,66 +73,192 @@ 'Plugin', ) -# todo: sort this, maybe move from mgr or move from aria??? -ACTION_TYPES = () -ENTITY_TYPES = () + +#pylint: disable=no-self-argument -class Blueprint(Model): +class Blueprint(SQLModelBase): """ - A Model which represents a blueprint + Blueprint model representation. """ - plan = Field(type=dict) - id = Field(type=basestring, default=uuid_generator) - description = Field(type=(basestring, NoneType)) - created_at = Field(type=datetime) - updated_at = Field(type=datetime) - main_file_name = Field(type=basestring) + __tablename__ = 'blueprints' + name = Column(Text, index=True) + created_at = Column(DateTime, nullable=False, index=True) + main_file_name = Column(Text, nullable=False) + plan = Column(Dict, nullable=False) + updated_at = Column(DateTime) + description = Column(Text) -class Snapshot(Model): + +class Deployment(SQLModelBase): """ - A Model which represents a snapshot + Deployment model representation. """ - CREATED = 'created' + __tablename__ = 'deployments' + + _private_fields = ['blueprint_id'] + + blueprint_id = foreign_key(Blueprint.id) + + name = Column(Text, index=True) + created_at = Column(DateTime, nullable=False, index=True) + description = Column(Text) + inputs = Column(Dict) + groups = Column(Dict) + permalink = Column(Text) + policy_triggers = Column(Dict) + policy_types = Column(Dict) + outputs = Column(Dict) + scaling_groups = Column(Dict) + updated_at = Column(DateTime) + workflows = Column(Dict) + + @declared_attr + def blueprint(cls): + return one_to_many_relationship(cls, Blueprint, cls.blueprint_id) + + +class Execution(SQLModelBase): + """ + Execution model representation. + """ + __tablename__ = 'executions' + + TERMINATED = 'terminated' FAILED = 'failed' - CREATING = 'creating' - UPLOADED = 'uploaded' - END_STATES = [CREATED, FAILED, UPLOADED] + CANCELLED = 'cancelled' + PENDING = 'pending' + STARTED = 'started' + CANCELLING = 'cancelling' + FORCE_CANCELLING = 'force_cancelling' - id = Field(type=basestring, default=uuid_generator) - created_at = Field(type=datetime) - status = Field(type=basestring) - error = Field(type=basestring, default=None) + STATES = [TERMINATED, FAILED, CANCELLED, PENDING, STARTED, CANCELLING, FORCE_CANCELLING] + END_STATES = [TERMINATED, FAILED, CANCELLED] + ACTIVE_STATES = [state for state in STATES if state not in END_STATES] + VALID_TRANSITIONS = { + PENDING: [STARTED, CANCELLED], + STARTED: END_STATES + [CANCELLING], + CANCELLING: END_STATES + } -class Deployment(Model): + @orm.validates('status') + def validate_status(self, key, value): + """Validation function that verifies execution status transitions are OK""" + try: + current_status = getattr(self, key) + except AttributeError: + return + valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, []) + if all([current_status is not None, + current_status != value, + value not in valid_transitions]): + raise ValueError('Cannot change execution status from {current} to {new}'.format( + current=current_status, + new=value)) + return value + + deployment_id = foreign_key(Deployment.id) + blueprint_id = foreign_key(Blueprint.id) + _private_fields = ['deployment_id', 'blueprint_id'] + + created_at = Column(DateTime, index=True) + started_at = Column(DateTime, nullable=True, index=True) + ended_at = Column(DateTime, nullable=True, index=True) + error = Column(Text, nullable=True) + is_system_workflow = Column(Boolean, nullable=False, default=False) + parameters = Column(Dict) + status = Column(Enum(*STATES, name='execution_status'), default=PENDING) + workflow_name = Column(Text, nullable=False) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + @declared_attr + def blueprint(cls): + return one_to_many_relationship(cls, Blueprint, cls.blueprint_id) + + def __str__(self): + return '<{0} id=`{1}` (status={2})>'.format( + self.__class__.__name__, + self.id, + self.status + ) + + +class DeploymentUpdate(SQLModelBase): """ - A Model which represents a deployment + Deployment update model representation. """ - id = Field(type=basestring, default=uuid_generator) - description = Field(type=(basestring, NoneType)) - created_at = Field(type=datetime) - updated_at = Field(type=datetime) - blueprint_id = Field(type=basestring) - workflows = Field(type=dict) - inputs = Field(type=dict, default=lambda: {}) - policy_types = Field(type=dict, default=lambda: {}) - policy_triggers = Field(type=dict, default=lambda: {}) - groups = Field(type=dict, default=lambda: {}) - outputs = Field(type=dict, default=lambda: {}) - scaling_groups = Field(type=dict, default=lambda: {}) - - -class DeploymentUpdateStep(Model): + __tablename__ = 'deployment_updates' + + deployment_id = foreign_key(Deployment.id) + execution_id = foreign_key(Execution.id, nullable=True) + _private_fields = ['execution_id', 'deployment_id'] + + created_at = Column(DateTime, nullable=False, index=True) + deployment_plan = Column(Dict, nullable=False) + deployment_update_node_instances = Column(Dict) + deployment_update_deployment = Column(Dict) + deployment_update_nodes = Column(Dict) + modified_entity_ids = Column(Dict) + state = Column(Text) + + @declared_attr + def execution(cls): + return one_to_many_relationship(cls, Execution, cls.execution_id) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + def to_dict(self, suppress_error=False, **kwargs): + dep_update_dict = super(DeploymentUpdate, self).to_dict(suppress_error) + # Taking care of the fact the DeploymentSteps are objects + dep_update_dict['steps'] = [step.to_dict() for step in self.steps] + return dep_update_dict + + +class DeploymentUpdateStep(SQLModelBase): """ - A Model which represents a deployment update step + Deployment update step model representation. """ - id = Field(type=basestring, default=uuid_generator) - action = Field(type=basestring, choices=ACTION_TYPES) - entity_type = Field(type=basestring, choices=ENTITY_TYPES) - entity_id = Field(type=basestring) - supported = Field(type=bool, default=True) + __tablename__ = 'deployment_update_steps' + _action_types = namedtuple('ACTION_TYPES', 'ADD, REMOVE, MODIFY') + ACTION_TYPES = _action_types(ADD='add', REMOVE='remove', MODIFY='modify') + _entity_types = namedtuple( + 'ENTITY_TYPES', + 'NODE, RELATIONSHIP, PROPERTY, OPERATION, WORKFLOW, OUTPUT, DESCRIPTION, GROUP, ' + 'POLICY_TYPE, POLICY_TRIGGER, PLUGIN') + ENTITY_TYPES = _entity_types( + NODE='node', + RELATIONSHIP='relationship', + PROPERTY='property', + OPERATION='operation', + WORKFLOW='workflow', + OUTPUT='output', + DESCRIPTION='description', + GROUP='group', + POLICY_TYPE='policy_type', + POLICY_TRIGGER='policy_trigger', + PLUGIN='plugin' + ) + + deployment_update_id = foreign_key(DeploymentUpdate.id) + _private_fields = ['deployment_update_id'] + + action = Column(Enum(*ACTION_TYPES, name='action_type'), nullable=False) + entity_id = Column(Text, nullable=False) + entity_type = Column(Enum(*ENTITY_TYPES, name='entity_type'), nullable=False) + + @declared_attr + def deployment_update(cls): + return one_to_many_relationship(cls, + DeploymentUpdate, + cls.deployment_update_id, + backreference='steps') def __hash__(self): return hash((self.id, self.entity_id)) @@ -148,265 +288,225 @@ def __lt__(self, other): return False -class DeploymentUpdate(Model): +class DeploymentModification(SQLModelBase): """ - A Model which represents a deployment update + Deployment modification model representation. """ - INITIALIZING = 'initializing' - SUCCESSFUL = 'successful' - UPDATING = 'updating' - FINALIZING = 'finalizing' - EXECUTING_WORKFLOW = 'executing_workflow' - FAILED = 'failed' + __tablename__ = 'deployment_modifications' - STATES = [ - INITIALIZING, - SUCCESSFUL, - UPDATING, - FINALIZING, - EXECUTING_WORKFLOW, - FAILED, - ] - - # '{0}-{1}'.format(kwargs['deployment_id'], uuid4()) - id = Field(type=basestring, default=uuid_generator) - deployment_id = Field(type=basestring) - state = Field(type=basestring, choices=STATES, default=INITIALIZING) - deployment_plan = Field() - deployment_update_nodes = Field(default=None) - deployment_update_node_instances = Field(default=None) - deployment_update_deployment = Field(default=None) - modified_entity_ids = Field(default=None) - execution_id = Field(type=basestring) - steps = IterPointerField(type=DeploymentUpdateStep, default=()) - - -class Execution(Model): - """ - A Model which represents an execution - """ + STARTED = 'started' + FINISHED = 'finished' + ROLLEDBACK = 'rolledback' - class _Validation(object): - - @staticmethod - def execution_status_transition_validation(_, value, instance): - """Validation function that verifies execution status transitions are OK""" - try: - current_status = instance.status - except AttributeError: - return - valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, []) - if current_status != value and value not in valid_transitions: - raise ValueError('Cannot change execution status from {current} to {new}'.format( - current=current_status, - new=value)) + STATES = [STARTED, FINISHED, ROLLEDBACK] + END_STATES = [FINISHED, ROLLEDBACK] - TERMINATED = 'terminated' - FAILED = 'failed' - CANCELLED = 'cancelled' - PENDING = 'pending' - STARTED = 'started' - CANCELLING = 'cancelling' - STATES = ( - TERMINATED, - FAILED, - CANCELLED, - PENDING, - STARTED, - CANCELLING, - ) - END_STATES = [TERMINATED, FAILED, CANCELLED] - ACTIVE_STATES = [state for state in STATES if state not in END_STATES] - VALID_TRANSITIONS = { - PENDING: [STARTED, CANCELLED], - STARTED: END_STATES + [CANCELLING], - CANCELLING: END_STATES - } + deployment_id = foreign_key(Deployment.id) + _private_fields = ['deployment_id'] - id = Field(type=basestring, default=uuid_generator) - status = Field(type=basestring, choices=STATES, - validation_func=_Validation.execution_status_transition_validation) - deployment_id = Field(type=basestring) - workflow_id = Field(type=basestring) - blueprint_id = Field(type=basestring) - created_at = Field(type=datetime, default=datetime.utcnow) - started_at = Field(type=datetime, default=None) - ended_at = Field(type=datetime, default=None) - error = Field(type=basestring, default=None) - parameters = Field() + context = Column(Dict) + created_at = Column(DateTime, nullable=False, index=True) + ended_at = Column(DateTime, index=True) + modified_nodes = Column(Dict) + node_instances = Column(Dict) + status = Column(Enum(*STATES, name='deployment_modification_status')) + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, + Deployment, + cls.deployment_id, + backreference='modifications') -class Relationship(Model): + +class Node(SQLModelBase): """ - A Model which represents a relationship + Node model representation. """ - id = Field(type=basestring, default=uuid_generator) - source_id = Field(type=basestring) - target_id = Field(type=basestring) - source_interfaces = Field(type=dict) - source_operations = Field(type=dict) - target_interfaces = Field(type=dict) - target_operations = Field(type=dict) - type = Field(type=basestring) - type_hierarchy = Field(type=list) - properties = Field(type=dict) - - -class Node(Model): + __tablename__ = 'nodes' + + # See base class for an explanation on these properties + is_id_unique = False + + name = Column(Text, index=True) + _private_fields = ['deployment_id', 'host_id'] + deployment_id = foreign_key(Deployment.id) + host_id = foreign_key('nodes.id', nullable=True) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + deploy_number_of_instances = Column(Integer, nullable=False) + # TODO: This probably should be a foreign key, but there's no guarantee + # in the code, currently, that the host will be created beforehand + max_number_of_instances = Column(Integer, nullable=False) + min_number_of_instances = Column(Integer, nullable=False) + number_of_instances = Column(Integer, nullable=False) + planned_number_of_instances = Column(Integer, nullable=False) + plugins = Column(Dict) + plugins_to_install = Column(Dict) + properties = Column(Dict) + operations = Column(Dict) + type = Column(Text, nullable=False, index=True) + type_hierarchy = Column(List) + + @declared_attr + def host(cls): + return relationship_to_self(cls, cls.host_id, cls.id) + + +class Relationship(SQLModelBase): """ - A Model which represents a node + Relationship model representation. """ - id = Field(type=basestring, default=uuid_generator) - blueprint_id = Field(type=basestring) - type = Field(type=basestring) - type_hierarchy = Field() - number_of_instances = Field(type=int) - planned_number_of_instances = Field(type=int) - deploy_number_of_instances = Field(type=int) - host_id = Field(type=basestring, default=None) - properties = Field(type=dict) - operations = Field(type=dict) - plugins = Field(type=list, default=()) - relationships = IterPointerField(type=Relationship) - plugins_to_install = Field(type=list, default=()) - min_number_of_instances = Field(type=int) - max_number_of_instances = Field(type=int) - - def relationships_by_target(self, target_id): - """ - Retreives all of the relationship by target. - :param target_id: the node id of the target of the relationship - :yields: a relationship which target and node with the specified target_id - """ - for relationship in self.relationships: - if relationship.target_id == target_id: - yield relationship - # todo: maybe add here Exception if isn't exists (didn't yield one's) + __tablename__ = 'relationships' + _private_fields = ['source_node_id', 'target_node_id'] -class RelationshipInstance(Model): - """ - A Model which represents a relationship instance - """ - id = Field(type=basestring, default=uuid_generator) - target_id = Field(type=basestring) - target_name = Field(type=basestring) - source_id = Field(type=basestring) - source_name = Field(type=basestring) - type = Field(type=basestring) - relationship = PointerField(type=Relationship) + source_node_id = foreign_key(Node.id) + target_node_id = foreign_key(Node.id) + + @declared_attr + def source_node(cls): + return one_to_many_relationship(cls, + Node, + cls.source_node_id, + 'outbound_relationships') + + @declared_attr + def target_node(cls): + return one_to_many_relationship(cls, + Node, + cls.target_node_id, + 'inbound_relationships') + source_interfaces = Column(Dict) + source_operations = Column(Dict, nullable=False) + target_interfaces = Column(Dict) + target_operations = Column(Dict, nullable=False) + type = Column(String, nullable=False) + type_hierarchy = Column(List) + properties = Column(Dict) -class NodeInstance(Model): + +class NodeInstance(SQLModelBase): """ - A Model which represents a node instance + Node instance model representation. """ - # todo: add statuses - UNINITIALIZED = 'uninitialized' - INITIALIZING = 'initializing' - CREATING = 'creating' - CONFIGURING = 'configuring' - STARTING = 'starting' - DELETED = 'deleted' - STOPPING = 'stopping' - DELETING = 'deleting' - STATES = ( - UNINITIALIZED, - INITIALIZING, - CREATING, - CONFIGURING, - STARTING, - DELETED, - STOPPING, - DELETING - ) + __tablename__ = 'node_instances' - id = Field(type=basestring, default=uuid_generator) - deployment_id = Field(type=basestring) - runtime_properties = Field(type=dict) - state = Field(type=basestring, choices=STATES, default=UNINITIALIZED) - version = Field(type=(basestring, NoneType)) - relationship_instances = IterPointerField(type=RelationshipInstance) - node = PointerField(type=Node) - host_id = Field(type=basestring, default=None) - scaling_groups = Field(default=()) - - def relationships_by_target(self, target_id): - """ - Retreives all of the relationship by target. - :param target_id: the instance id of the target of the relationship - :yields: a relationship instance which target and node with the specified target_id - """ - for relationship_instance in self.relationship_instances: - if relationship_instance.target_id == target_id: - yield relationship_instance - # todo: maybe add here Exception if isn't exists (didn't yield one's) + node_id = foreign_key(Node.id) + deployment_id = foreign_key(Deployment.id) + host_id = foreign_key('node_instances.id', nullable=True) + + _private_fields = ['node_id', 'host_id'] + + name = Column(Text, index=True) + runtime_properties = Column(Dict) + scaling_groups = Column(Dict) + state = Column(Text, nullable=False) + version = Column(Integer, default=1) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + @declared_attr + def node(cls): + return one_to_many_relationship(cls, Node, cls.node_id) + @declared_attr + def host(cls): + return relationship_to_self(cls, cls.host_id, cls.id) -class DeploymentModification(Model): + +class RelationshipInstance(SQLModelBase): """ - A Model which represents a deployment modification + Relationship instance model representation. """ - STARTED = 'started' - FINISHED = 'finished' - ROLLEDBACK = 'rolledback' - END_STATES = [FINISHED, ROLLEDBACK] + __tablename__ = 'relationship_instances' + + relationship_id = foreign_key(Relationship.id) + source_node_instance_id = foreign_key(NodeInstance.id) + target_node_instance_id = foreign_key(NodeInstance.id) + + _private_fields = ['relationship_storage_id', + 'source_node_instance_id', + 'target_node_instance_id'] - id = Field(type=basestring, default=uuid_generator) - deployment_id = Field(type=basestring) - modified_nodes = Field(type=(dict, NoneType)) - added_and_related = IterPointerField(type=NodeInstance) - removed_and_related = IterPointerField(type=NodeInstance) - extended_and_related = IterPointerField(type=NodeInstance) - reduced_and_related = IterPointerField(type=NodeInstance) - # before_modification = IterPointerField(type=NodeInstance) - status = Field(type=basestring, choices=(STARTED, FINISHED, ROLLEDBACK)) - created_at = Field(type=datetime) - ended_at = Field(type=(datetime, NoneType)) - context = Field() - - -class ProviderContext(Model): + @declared_attr + def source_node_instance(cls): + return one_to_many_relationship(cls, + NodeInstance, + cls.source_node_instance_id, + 'outbound_relationship_instances') + + @declared_attr + def target_node_instance(cls): + return one_to_many_relationship(cls, + NodeInstance, + cls.target_node_instance_id, + 'inbound_relationship_instances') + + @declared_attr + def relationship(cls): + return one_to_many_relationship(cls, Relationship, cls.relationship_id) + + +class ProviderContext(SQLModelBase): """ - A Model which represents a provider context + Provider context model representation. """ - id = Field(type=basestring, default=uuid_generator) - context = Field(type=dict) - name = Field(type=basestring) + __tablename__ = 'provider_context' + + name = Column(Text, nullable=False) + context = Column(Dict, nullable=False) -class Plugin(Model): +class Plugin(SQLModelBase): """ - A Model which represents a plugin + Plugin model representation. """ - id = Field(type=basestring, default=uuid_generator) - package_name = Field(type=basestring) - archive_name = Field(type=basestring) - package_source = Field(type=dict) - package_version = Field(type=basestring) - supported_platform = Field(type=basestring) - distribution = Field(type=basestring) - distribution_version = Field(type=basestring) - distribution_release = Field(type=basestring) - wheels = Field() - excluded_wheels = Field() - supported_py_versions = Field(type=list) - uploaded_at = Field(type=datetime) - - -class Task(Model): + __tablename__ = 'plugins' + + archive_name = Column(Text, nullable=False, index=True) + distribution = Column(Text) + distribution_release = Column(Text) + distribution_version = Column(Text) + excluded_wheels = Column(Dict) + package_name = Column(Text, nullable=False, index=True) + package_source = Column(Text) + package_version = Column(Text) + supported_platform = Column(Dict) + supported_py_versions = Column(Dict) + uploaded_at = Column(DateTime, nullable=False, index=True) + wheels = Column(Dict, nullable=False) + + +class Task(SQLModelBase): """ A Model which represents an task """ - class _Validation(object): + __tablename__ = 'task' + node_instance_id = foreign_key(NodeInstance.id, nullable=True) + relationship_instance_id = foreign_key(RelationshipInstance.id, nullable=True) + execution_id = foreign_key(Execution.id, nullable=True) + + _private_fields = ['node_instance_id', + 'relationship_instance_id', + 'execution_id'] - @staticmethod - def validate_max_attempts(_, value, *args): - """Validates that max attempts is either -1 or a positive number""" - if value < 1 and value != Task.INFINITE_RETRIES: - raise ValueError('Max attempts can be either -1 (infinite) or any positive number. ' - 'Got {value}'.format(value=value)) + @declared_attr + def node_instance(cls): + return one_to_many_relationship(cls, NodeInstance, cls.node_instance_id) + + @declared_attr + def relationship_instance(cls): + return one_to_many_relationship(cls, + RelationshipInstance, + cls.relationship_instance_id) PENDING = 'pending' RETRYING = 'retrying' @@ -422,23 +522,51 @@ def validate_max_attempts(_, value, *args): SUCCESS, FAILED, ) + WAIT_STATES = [PENDING, RETRYING] END_STATES = [SUCCESS, FAILED] + + @orm.validates('max_attempts') + def validate_max_attempts(self, _, value): # pylint: disable=no-self-use + """Validates that max attempts is either -1 or a positive number""" + if value < 1 and value != Task.INFINITE_RETRIES: + raise ValueError('Max attempts can be either -1 (infinite) or any positive number. ' + 'Got {value}'.format(value=value)) + return value + INFINITE_RETRIES = -1 - id = Field(type=basestring, default=uuid_generator) - status = Field(type=basestring, choices=STATES, default=PENDING) - execution_id = Field(type=basestring) - due_at = Field(type=datetime, default=datetime.utcnow) - started_at = Field(type=datetime, default=None) - ended_at = Field(type=datetime, default=None) - max_attempts = Field(type=int, default=1, validation_func=_Validation.validate_max_attempts) - retry_count = Field(type=int, default=0) - retry_interval = Field(type=(int, float), default=0) - ignore_failure = Field(type=bool, default=False) + status = Column(Enum(*STATES), name='status', default=PENDING) + + due_at = Column(DateTime, default=datetime.utcnow) + started_at = Column(DateTime, default=None) + ended_at = Column(DateTime, default=None) + max_attempts = Column(Integer, default=1) + retry_count = Column(Integer, default=0) + retry_interval = Column(Float, default=0) + ignore_failure = Column(Boolean, default=False) # Operation specific fields - name = Field(type=basestring) - operation_mapping = Field(type=basestring) - actor = Field() - inputs = Field(type=dict, default=lambda: {}) + name = Column(String) + operation_mapping = Column(String) + inputs = Column(Dict) + + @declared_attr + def execution(cls): + return one_to_many_relationship(cls, Execution, cls.execution_id) + + @property + def actor(self): + """ + Return the actor of the task + :return: + """ + return self.node_instance or self.relationship_instance + + @classmethod + def as_node_instance(cls, instance_id, **kwargs): + return cls(node_instance_id=instance_id, **kwargs) + + @classmethod + def as_relationship_instance(cls, instance_id, **kwargs): + return cls(relationship_instance_id=instance_id, **kwargs) diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py new file mode 100644 index 00000000..cde40c2c --- /dev/null +++ b/aria/storage/sql_mapi.py @@ -0,0 +1,382 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SQLAlchemy based MAPI +""" + +from sqlalchemy.exc import SQLAlchemyError + +from aria.utils.collections import OrderedDict +from aria.storage import ( + api, + exceptions +) + + +class SQLAlchemyModelAPI(api.ModelAPI): + """ + SQL based MAPI. + """ + + def __init__(self, + engine, + session, + **kwargs): + super(SQLAlchemyModelAPI, self).__init__(**kwargs) + self._engine = engine + self._session = session + + def get(self, entry_id, include=None, **kwargs): + """Return a single result based on the model class and element ID + """ + query = self._get_query(include, {'id': entry_id}) + result = query.first() + + if not result: + raise exceptions.StorageError( + 'Requested {0} with ID `{1}` was not found' + .format(self.model_cls.__name__, entry_id) + ) + return result + + def get_by_name(self, entry_name, include=None, **kwargs): + assert hasattr(self.model_cls, 'name') + result = self.list(include=include, filters={'name': entry_name}) + if not result: + raise exceptions.StorageError( + 'Requested {0} with NAME `{1}` was not found' + .format(self.model_cls.__name__, entry_name) + ) + elif len(result) > 1: + raise exceptions.StorageError( + 'Requested {0} with NAME `{1}` returned more than 1 value' + .format(self.model_cls.__name__, entry_name) + ) + else: + return result[0] + + def list(self, + include=None, + filters=None, + pagination=None, + sort=None, + **kwargs): + query = self._get_query(include, filters, sort) + + results, total, size, offset = self._paginate(query, pagination) + + return ListResult( + items=results, + metadata=dict(total=total, + size=size, + offset=offset) + ) + + def iter(self, + include=None, + filters=None, + sort=None, + **kwargs): + """Return a (possibly empty) list of `model_class` results + """ + return iter(self._get_query(include, filters, sort)) + + def put(self, entry, **kwargs): + """Create a `model_class` instance from a serializable `model` object + + :param entry: A dict with relevant kwargs, or an instance of a class + that has a `to_dict` method, and whose attributes match the columns + of `model_class` (might also my just an instance of `model_class`) + :return: An instance of `model_class` + """ + self._session.add(entry) + self._safe_commit() + return entry + + def delete(self, entry, **kwargs): + """Delete a single result based on the model class and element ID + """ + self._load_relationships(entry) + self._session.delete(entry) + self._safe_commit() + return entry + + def update(self, entry, **kwargs): + """Add `instance` to the DB session, and attempt to commit + + :return: The updated instance + """ + return self.put(entry) + + def refresh(self, entry): + """Reload the instance with fresh information from the DB + + :param entry: Instance to be re-loaded from the DB + :return: The refreshed instance + """ + self._session.refresh(entry) + self._load_relationships(entry) + return entry + + def _destroy_connection(self): + pass + + def _establish_connection(self): + pass + + def create(self, checkfirst=True, **kwargs): + self.model_cls.__table__.create(self._engine, checkfirst=checkfirst) + + def drop(self): + """ + Drop the table from the storage. + :return: + """ + self.model_cls.__table__.drop(self._engine) + + def _safe_commit(self): + """Try to commit changes in the session. Roll back if exception raised + Excepts SQLAlchemy errors and rollbacks if they're caught + """ + try: + self._session.commit() + except (SQLAlchemyError, ValueError) as e: + self._session.rollback() + raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) + + def _get_base_query(self, include, joins): + """Create the initial query from the model class and included columns + + :param include: A (possibly empty) list of columns to include in + the query + :return: An SQLAlchemy AppenderQuery object + """ + # If only some columns are included, query through the session object + if include: + # Make sure that attributes come before association proxies + include.sort(key=lambda x: x.is_clause_element) + query = self._session.query(*include) + else: + # If all columns should be returned, query directly from the model + query = self._session.query(self.model_cls) + + if not self._skip_joining(joins, include): + for join_table in joins: + query = query.join(join_table) + + return query + + @staticmethod + def _get_joins(model_class, columns): + """Get a list of all the tables on which we need to join + + :param columns: A set of all columns involved in the query + """ + joins = [] # Using a list instead of a set because order is important + for column_name in columns: + column = getattr(model_class, column_name) + while not column.is_attribute: + column = column.remote_attr + if column.is_attribute: + join_class = column.class_ + else: + join_class = column.local_attr.class_ + + # Don't add the same class more than once + if join_class not in joins: + joins.append(join_class) + return joins + + @staticmethod + def _skip_joining(joins, include): + """Dealing with an edge case where the only included column comes from + an other table. In this case, we mustn't join on the same table again + + :param joins: A list of tables on which we're trying to join + :param include: The list of + :return: True if we need to skip joining + """ + if not joins: + return True + join_table_names = [t.__tablename__ for t in joins] + + if len(include) != 1: + return False + + column = include[0] + if column.is_clause_element: + table_name = column.element.table.name + else: + table_name = column.class_.__tablename__ + return table_name in join_table_names + + @staticmethod + def _sort_query(query, sort=None): + """Add sorting clauses to the query + + :param query: Base SQL query + :param sort: An optional dictionary where keys are column names to + sort by, and values are the order (asc/desc) + :return: An SQLAlchemy AppenderQuery object + """ + if sort: + for column, order in sort.items(): + if order == 'desc': + column = column.desc() + query = query.order_by(column) + return query + + def _filter_query(self, query, filters): + """Add filter clauses to the query + + :param query: Base SQL query + :param filters: An optional dictionary where keys are column names to + filter by, and values are values applicable for those columns (or lists + of such values) + :return: An SQLAlchemy AppenderQuery object + """ + return self._add_value_filter(query, filters) + + @staticmethod + def _add_value_filter(query, filters): + for column, value in filters.items(): + if isinstance(value, (list, tuple)): + query = query.filter(column.in_(value)) + else: + query = query.filter(column == value) + + return query + + def _get_query(self, + include=None, + filters=None, + sort=None): + """Get an SQL query object based on the params passed + + :param model_class: SQL DB table class + :param include: An optional list of columns to include in the query + :param filters: An optional dictionary where keys are column names to + filter by, and values are values applicable for those columns (or lists + of such values) + :param sort: An optional dictionary where keys are column names to + sort by, and values are the order (asc/desc) + :return: A sorted and filtered query with only the relevant + columns + """ + include, filters, sort, joins = self._get_joins_and_converted_columns( + include, filters, sort + ) + + query = self._get_base_query(include, joins) + query = self._filter_query(query, filters) + query = self._sort_query(query, sort) + return query + + def _get_joins_and_converted_columns(self, + include, + filters, + sort): + """Get a list of tables on which we need to join and the converted + `include`, `filters` and `sort` arguments (converted to actual SQLA + column/label objects instead of column names) + """ + include = include or [] + filters = filters or dict() + sort = sort or OrderedDict() + + all_columns = set(include) | set(filters.keys()) | set(sort.keys()) + joins = self._get_joins(self.model_cls, all_columns) + + include, filters, sort = self._get_columns_from_field_names( + include, filters, sort + ) + return include, filters, sort, joins + + def _get_columns_from_field_names(self, + include, + filters, + sort): + """Go over the optional parameters (include, filters, sort), and + replace column names with actual SQLA column objects + """ + include = [self._get_column(c) for c in include] + filters = dict((self._get_column(c), filters[c]) for c in filters) + sort = OrderedDict((self._get_column(c), sort[c]) for c in sort) + + return include, filters, sort + + def _get_column(self, column_name): + """Return the column on which an action (filtering, sorting, etc.) + would need to be performed. Can be either an attribute of the class, + or an association proxy linked to a relationship the class has + """ + column = getattr(self.model_cls, column_name) + if column.is_attribute: + return column + else: + # We need to get to the underlying attribute, so we move on to the + # next remote_attr until we reach one + while not column.remote_attr.is_attribute: + column = column.remote_attr + # Put a label on the remote attribute with the name of the column + return column.remote_attr.label(column_name) + + @staticmethod + def _paginate(query, pagination): + """Paginate the query by size and offset + + :param query: Current SQLAlchemy query object + :param pagination: An optional dict with size and offset keys + :return: A tuple with four elements: + - res ults: `size` items starting from `offset` + - the total count of items + - `size` [default: 0] + - `offset` [default: 0] + """ + if pagination: + size = pagination.get('size', 0) + offset = pagination.get('offset', 0) + total = query.order_by(None).count() # Fastest way to count + results = query.limit(size).offset(offset).all() + return results, total, size, offset + else: + results = query.all() + return results, len(results), 0, 0 + + @staticmethod + def _load_relationships(instance): + """A helper method used to overcome a problem where the relationships + that rely on joins aren't being loaded automatically + """ + for rel in instance.__mapper__.relationships: + getattr(instance, rel.key) + + +class ListResult(object): + """ + a ListResult contains results about the requested items. + """ + def __init__(self, items, metadata): + self.items = items + self.metadata = metadata + + def __len__(self): + return len(self.items) + + def __iter__(self): + return iter(self.items) + + def __getitem__(self, item): + return self.items[item] diff --git a/aria/storage/structures.py b/aria/storage/structures.py index b02366e9..8dbd2a94 100644 --- a/aria/storage/structures.py +++ b/aria/storage/structures.py @@ -27,281 +27,218 @@ * Model - abstract model implementation. """ import json -from itertools import count -from uuid import uuid4 - -from .exceptions import StorageError -from ..logger import LoggerMixin -from ..utils.validation import ValidatorMixin - -__all__ = ( - 'uuid_generator', - 'Field', - 'IterField', - 'PointerField', - 'IterPointerField', - 'Model', - 'Storage', + +from sqlalchemy.ext.mutable import Mutable +from sqlalchemy.orm import relationship, backref +from sqlalchemy.ext.declarative import declarative_base +# pylint: disable=unused-import +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy import ( + schema, + VARCHAR, + ARRAY, + Column, + Integer, + Text, + DateTime, + Boolean, + Enum, + String, + PickleType, + Float, + TypeDecorator, + ForeignKey, + orm, ) +from aria.storage import exceptions + +Model = declarative_base() -def uuid_generator(): - """ - wrapper function which generates ids - """ - return str(uuid4()) +def foreign_key(foreign_key_column, nullable=False): + """Return a ForeignKey object with the relevant -class Field(ValidatorMixin): + :param foreign_key_column: Unique id column in the parent table + :param nullable: Should the column be allowed to remain empty """ - A single field implementation + return Column( + ForeignKey(foreign_key_column, ondelete='CASCADE'), + nullable=nullable + ) + + +def one_to_many_relationship(child_class, + parent_class, + foreign_key_column, + backreference=None): + """Return a one-to-many SQL relationship object + Meant to be used from inside the *child* object + + :param parent_class: Class of the parent table + :param child_class: Class of the child table + :param foreign_key_column: The column of the foreign key + :param backreference: The name to give to the reference to the child """ - NO_DEFAULT = 'NO_DEFAULT' - - try: - # python 3 syntax - _next_id = count().__next__ - except AttributeError: - # python 2 syntax - _next_id = count().next - _ATTRIBUTE_NAME = '_cache_{0}'.format - - def __init__( - self, - type=None, - choices=(), - validation_func=None, - default=NO_DEFAULT, - **kwargs): - """ - Simple field manager. + backreference = backreference or child_class.__tablename__ + return relationship( + parent_class, + primaryjoin=lambda: parent_class.id == foreign_key_column, + # The following line make sure that when the *parent* is + # deleted, all its connected children are deleted as well + backref=backref(backreference, cascade='all') + ) - :param type: possible type of the field. - :param choices: a set of possible field values. - :param default: default field value. - :param kwargs: kwargs to be passed to next in line classes. - """ - self.type = type - self.choices = choices - self.default = default - self.validation_func = validation_func - super(Field, self).__init__(**kwargs) - - def __get__(self, instance, owner): - if instance is None: - return self - field_name = self._field_name(instance) - try: - return getattr(instance, self._ATTRIBUTE_NAME(field_name)) - except AttributeError as exc: - if self.default == self.NO_DEFAULT: - raise AttributeError( - str(exc).replace(self._ATTRIBUTE_NAME(field_name), field_name)) - - default_value = self.default() if callable(self.default) else self.default - setattr(instance, self._ATTRIBUTE_NAME(field_name), default_value) - return default_value - - def __set__(self, instance, value): - field_name = self._field_name(instance) - self.validate_value(field_name, value, instance) - setattr(instance, self._ATTRIBUTE_NAME(field_name), value) - - def validate_value(self, name, value, instance): - """ - Validates the value of the field. - :param name: the name of the field. - :param value: the value of the field. - :param instance: the instance containing the field. - """ - if self.default != self.NO_DEFAULT and value == self.default: - return - if self.type: - self.validate_instance(name, value, self.type) - if self.choices: - self.validate_in_choice(name, value, self.choices) - if self.validation_func: - self.validation_func(name, value, instance) - - def _field_name(self, instance): - """ - retrieves the field name from the instance. - - :param Field instance: the instance which holds the field. - :return: name of the field - :rtype: basestring - """ - for name, member in vars(instance.__class__).iteritems(): - if member is self: - return name +def relationship_to_self(self_cls, parent_key, self_key): + return relationship( + self_cls, + foreign_keys=parent_key, + remote_side=self_key + ) -class IterField(Field): +class _MutableType(TypeDecorator): """ - Represents an iterable field. + Dict representation of type. """ - def __init__(self, **kwargs): - """ - Simple iterable field manager. - This field type don't have choices option. - - :param kwargs: kwargs to be passed to next in line classes. - """ - super(IterField, self).__init__(choices=(), **kwargs) + @property + def python_type(self): + raise NotImplementedError - def validate_value(self, name, values, *args): - """ - Validates the value of each iterable value. + def process_literal_param(self, value, dialect): + pass - :param name: the name of the field. - :param values: the values of the field. - """ - for value in values: - self.validate_instance(name, value, self.type) + impl = VARCHAR + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value -class PointerField(Field): - """ - A single pointer field implementation. - - Any PointerField points via id to another document. - """ + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value - def __init__(self, type, **kwargs): - assert issubclass(type, Model) - super(PointerField, self).__init__(type=type, **kwargs) +class _DictType(_MutableType): + @property + def python_type(self): + return dict -class IterPointerField(IterField, PointerField): - """ - An iterable pointers field. - Any IterPointerField points via id to other documents. - """ - pass +class _ListType(_MutableType): + @property + def python_type(self): + return list -class Model(object): +class _MutableDict(Mutable, dict): """ - Base class for all of the storage models. + Enables tracking for dict values. """ - id = None + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." - def __init__(self, **fields): - """ - Abstract class for any model in the storage. - The Initializer creates attributes according to the (keyword arguments) that given - Each value is validated according to the Field. - Each model has to have and ID Field. + if not isinstance(value, _MutableDict): + if isinstance(value, dict): + return _MutableDict(value) - :param fields: each item is validated and transformed into instance attributes. - """ - self._assert_model_have_id_field(**fields) - missing_fields, unexpected_fields = self._setup_fields(fields) + # this call will raise ValueError + try: + return Mutable.coerce(key, value) + except ValueError as e: + raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) + else: + return value - if missing_fields: - raise StorageError( - 'Model {name} got missing keyword arguments: {fields}'.format( - name=self.__class__.__name__, fields=missing_fields)) + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." - if unexpected_fields: - raise StorageError( - 'Model {name} got unexpected keyword arguments: {fields}'.format( - name=self.__class__.__name__, fields=unexpected_fields)) + dict.__setitem__(self, key, value) + self.changed() - def __repr__(self): - return '{name}(fields={0})'.format(sorted(self.fields), name=self.__class__.__name__) + def __delitem__(self, key): + "Detect dictionary del events and emit change events." - def __eq__(self, other): - return ( - isinstance(other, self.__class__) and - self.fields_dict == other.fields_dict) + dict.__delitem__(self, key) + self.changed() - @property - def fields(self): - """ - Iterates over the fields of the model. - :yields: the class's field name - """ - for name, field in vars(self.__class__).items(): - if isinstance(field, Field): - yield name - @property - def fields_dict(self): - """ - Transforms the instance attributes into a dict. +class _MutableList(Mutable, list): - :return: all fields in dict format. - :rtype dict - """ - return dict((name, getattr(self, name)) for name in self.fields) + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." - @property - def json(self): - """ - Transform the dict of attributes into json - :return: - """ - return json.dumps(self.fields_dict) + if not isinstance(value, _MutableList): + if isinstance(value, list): + return _MutableList(value) - @classmethod - def _assert_model_have_id_field(cls, **fields_initializer_values): - if not getattr(cls, 'id', None): - raise StorageError('Model {cls.__name__} must have id field'.format(cls=cls)) - - if cls.id.default == cls.id.NO_DEFAULT and 'id' not in fields_initializer_values: - raise StorageError( - 'Model {cls.__name__} is missing required ' - 'keyword-only argument: "id"'.format(cls=cls)) - - def _setup_fields(self, input_fields): - missing = [] - for field_name in self.fields: + # this call will raise ValueError try: - field_obj = input_fields.pop(field_name) - setattr(self, field_name, field_obj) - except KeyError: - field = getattr(self.__class__, field_name) - if field.default == field.NO_DEFAULT: - missing.append(field_name) + return Mutable.coerce(key, value) + except ValueError as e: + raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) + else: + return value + + def __setitem__(self, key, value): + list.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + list.__delitem__(self, key) + - unexpected_fields = input_fields.keys() - return missing, unexpected_fields +Dict = _MutableDict.as_mutable(_DictType) +List = _MutableList.as_mutable(_ListType) -class Storage(LoggerMixin): +class SQLModelBase(Model): """ - Represents the storage + Abstract base class for all SQL models that allows [de]serialization """ - def __init__(self, driver, items=(), **kwargs): - super(Storage, self).__init__(**kwargs) - self.driver = driver - self.registered = {} - for item in items: - self.register(item) - self.logger.debug('{name} object is ready: {0!r}'.format( - self, name=self.__class__.__name__)) + # SQLAlchemy syntax + __abstract__ = True - def __repr__(self): - return '{name}(driver={self.driver})'.format( - name=self.__class__.__name__, self=self) + # This would be overridden once the models are created. Created for pylint. + __table__ = None + + _private_fields = [] + + id = Column(Integer, primary_key=True, autoincrement=True) - def __getattr__(self, item): - try: - return self.registered[item] - except KeyError: - return super(Storage, self).__getattribute__(item) + def to_dict(self, suppress_error=False): + """Return a dict representation of the model - def setup(self): + :param suppress_error: If set to True, sets `None` to attributes that + it's unable to retrieve (e.g., if a relationship wasn't established + yet, and so it's impossible to access a property through it) """ - Setup and create all storage items + if suppress_error: + res = dict() + for field in self.fields(): + try: + field_value = getattr(self, field) + except AttributeError: + field_value = None + res[field] = field_value + else: + # Can't simply call here `self.to_response()` because inheriting + # class might override it, but we always need the same code here + res = dict((f, getattr(self, f)) for f in self.fields()) + return res + + @classmethod + def fields(cls): + """Return the list of field names for this table + + Mostly for backwards compatibility in the code (that uses `fields`) """ - for name, api in self.registered.iteritems(): - try: - api.create() - self.logger.debug( - 'setup {name} in storage {self!r}'.format(name=name, self=self)) - except StorageError: - pass + return set(cls.__table__.columns.keys()) - set(cls._private_fields) + + def __repr__(self): + return '<{0} id=`{1}`>'.format(self.__class__.__name__, self.id) diff --git a/aria/utils/application.py b/aria/utils/application.py index b1a7fcc8..113e054b 100644 --- a/aria/utils/application.py +++ b/aria/utils/application.py @@ -117,7 +117,7 @@ def create_blueprint_storage(self, source, main_file_name=None): updated_at=now, main_file_name=main_file_name, ) - self.model_storage.blueprint.store(blueprint) + self.model_storage.blueprint.put(blueprint) self.logger.debug('created blueprint model storage entry') def create_nodes_storage(self): @@ -138,7 +138,7 @@ def create_nodes_storage(self): scalable = node_copy.pop('capabilities')['scalable']['properties'] for index, relationship in enumerate(node_copy['relationships']): relationship = self.model_storage.relationship.model_cls(**relationship) - self.model_storage.relationship.store(relationship) + self.model_storage.relationship.put(relationship) node_copy['relationships'][index] = relationship node_copy = self.model_storage.node.model_cls( @@ -149,7 +149,7 @@ def create_nodes_storage(self): max_number_of_instances=scalable['max_instances'], number_of_instances=scalable['current_instances'], **node_copy) - self.model_storage.node.store(node_copy) + self.model_storage.node.put(node_copy) def create_deployment_storage(self): """ @@ -190,7 +190,7 @@ def create_deployment_storage(self): created_at=now, updated_at=now ) - self.model_storage.deployment.store(deployment) + self.model_storage.deployment.put(deployment) self.logger.debug('created deployment model storage entry') def create_node_instances_storage(self): @@ -213,7 +213,7 @@ def create_node_instances_storage(self): type=relationship_instance['type'], target_id=relationship_instance['target_id']) relationship_instances.append(relationship_instance_model) - self.model_storage.relationship_instance.store(relationship_instance_model) + self.model_storage.relationship_instance.put(relationship_instance_model) node_instance_model = self.model_storage.node_instance.model_cls( node=node_model, @@ -224,7 +224,7 @@ def create_node_instances_storage(self): version='1.0', relationship_instances=relationship_instances) - self.model_storage.node_instance.store(node_instance_model) + self.model_storage.node_instance.put(node_instance_model) self.logger.debug('created node-instances model storage entries') def create_plugin_storage(self, plugin_id, source): @@ -258,7 +258,7 @@ def create_plugin_storage(self, plugin_id, source): supported_py_versions=plugin.get('supported_python_versions'), uploaded_at=now ) - self.model_storage.plugin.store(plugin) + self.model_storage.plugin.put(plugin) self.logger.debug('created plugin model storage entry') diff --git a/requirements.txt b/requirements.txt index e6d53936..7e87c678 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ Jinja2==2.8 shortuuid==0.4.3 CacheControl[filecache]==0.11.6 clint==0.5.1 +SQLAlchemy==1.1.4 \ No newline at end of file diff --git a/tests/mock/context.py b/tests/mock/context.py index 5fda07eb..19041407 100644 --- a/tests/mock/context.py +++ b/tests/mock/context.py @@ -15,23 +15,53 @@ from aria import application_model_storage from aria.orchestrator import context +from aria.storage.sql_mapi import SQLAlchemyModelAPI from . import models -from ..storage import InMemoryModelDriver -def simple(**kwargs): - storage = application_model_storage(InMemoryModelDriver()) - storage.setup() - storage.blueprint.store(models.get_blueprint()) - storage.deployment.store(models.get_deployment()) +def simple(api_kwargs, **kwargs): + model_storage = application_model_storage(SQLAlchemyModelAPI, api_kwargs=api_kwargs) + blueprint = models.get_blueprint() + model_storage.blueprint.put(blueprint) + deployment = models.get_deployment(blueprint) + model_storage.deployment.put(deployment) + + ################################################################################# + # Creating a simple deployment with node -> node as a graph + + dependency_node = models.get_dependency_node(deployment) + model_storage.node.put(dependency_node) + storage_dependency_node = model_storage.node.get(dependency_node.id) + + dependency_node_instance = models.get_dependency_node_instance(storage_dependency_node) + model_storage.node_instance.put(dependency_node_instance) + storage_dependency_node_instance = model_storage.node_instance.get(dependency_node_instance.id) + + dependent_node = models.get_dependent_node(deployment) + model_storage.node.put(dependent_node) + storage_dependent_node = model_storage.node.get(dependent_node.id) + + dependent_node_instance = models.get_dependent_node_instance(storage_dependent_node) + model_storage.node_instance.put(dependent_node_instance) + storage_dependent_node_instance = model_storage.node_instance.get(dependent_node_instance.id) + + relationship = models.get_relationship(storage_dependent_node, storage_dependency_node) + model_storage.relationship.put(relationship) + storage_relationship = model_storage.relationship.get(relationship.id) + relationship_instance = models.get_relationship_instance( + relationship=storage_relationship, + target_instance=storage_dependency_node_instance, + source_instance=storage_dependent_node_instance + ) + model_storage.relationship_instance.put(relationship_instance) + final_kwargs = dict( name='simple_context', - model_storage=storage, + model_storage=model_storage, resource_storage=None, - deployment_id=models.DEPLOYMENT_ID, - workflow_id=models.WORKFLOW_ID, - execution_id=models.EXECUTION_ID, + deployment_id=deployment.id, + workflow_name=models.WORKFLOW_NAME, task_max_attempts=models.TASK_MAX_ATTEMPTS, task_retry_interval=models.TASK_RETRY_INTERVAL ) diff --git a/tests/mock/models.py b/tests/mock/models.py index 327b0b9a..e2e3d2fa 100644 --- a/tests/mock/models.py +++ b/tests/mock/models.py @@ -19,24 +19,24 @@ from . import operations -DEPLOYMENT_ID = 'test_deployment_id' -BLUEPRINT_ID = 'test_blueprint_id' -WORKFLOW_ID = 'test_workflow_id' -EXECUTION_ID = 'test_execution_id' +DEPLOYMENT_NAME = 'test_deployment_id' +BLUEPRINT_NAME = 'test_blueprint_id' +WORKFLOW_NAME = 'test_workflow_id' +EXECUTION_NAME = 'test_execution_id' TASK_RETRY_INTERVAL = 1 TASK_MAX_ATTEMPTS = 1 -DEPENDENCY_NODE_ID = 'dependency_node' -DEPENDENCY_NODE_INSTANCE_ID = 'dependency_node_instance' -DEPENDENT_NODE_ID = 'dependent_node' -DEPENDENT_NODE_INSTANCE_ID = 'dependent_node_instance' +DEPENDENCY_NODE_NAME = 'dependency_node' +DEPENDENCY_NODE_INSTANCE_NAME = 'dependency_node_instance' +DEPENDENT_NODE_NAME = 'dependent_node' +DEPENDENT_NODE_INSTANCE_NAME = 'dependent_node_instance' +RELATIONSHIP_NAME = 'relationship' +RELATIONSHIP_INSTANCE_NAME = 'relationship_instance' -def get_dependency_node(): +def get_dependency_node(deployment): return models.Node( - id=DEPENDENCY_NODE_ID, - host_id=DEPENDENCY_NODE_ID, - blueprint_id=BLUEPRINT_ID, + name=DEPENDENCY_NODE_NAME, type='test_node_type', type_hierarchy=[], number_of_instances=1, @@ -44,28 +44,28 @@ def get_dependency_node(): deploy_number_of_instances=1, properties={}, operations=dict((key, {}) for key in operations.NODE_OPERATIONS), - relationships=[], min_number_of_instances=1, max_number_of_instances=1, + deployment_id=deployment.id ) -def get_dependency_node_instance(dependency_node=None): +def get_dependency_node_instance(dependency_node): return models.NodeInstance( - id=DEPENDENCY_NODE_INSTANCE_ID, - host_id=DEPENDENCY_NODE_INSTANCE_ID, - deployment_id=DEPLOYMENT_ID, + name=DEPENDENCY_NODE_INSTANCE_NAME, runtime_properties={'ip': '1.1.1.1'}, version=None, - relationship_instances=[], - node=dependency_node or get_dependency_node() + node_id=dependency_node.id, + deployment_id=dependency_node.deployment.id, + state='', + scaling_groups={} ) def get_relationship(source=None, target=None): return models.Relationship( - source_id=source.id if source is not None else DEPENDENT_NODE_ID, - target_id=target.id if target is not None else DEPENDENCY_NODE_ID, + source_node_id=source.id, + target_node_id=target.id, source_interfaces={}, source_operations=dict((key, {}) for key in operations.RELATIONSHIP_OPERATIONS), target_interfaces={}, @@ -76,23 +76,18 @@ def get_relationship(source=None, target=None): ) -def get_relationship_instance(source_instance=None, target_instance=None, relationship=None): +def get_relationship_instance(source_instance, target_instance, relationship): return models.RelationshipInstance( - target_id=target_instance.id if target_instance else DEPENDENCY_NODE_INSTANCE_ID, - target_name='test_target_name', - source_id=source_instance.id if source_instance else DEPENDENT_NODE_INSTANCE_ID, - source_name='test_source_name', - type='some_type', - relationship=relationship or get_relationship(target_instance.node - if target_instance else None) + relationship_id=relationship.id, + target_node_instance_id=target_instance.id, + source_node_instance_id=source_instance.id, ) -def get_dependent_node(relationship=None): +def get_dependent_node(deployment): return models.Node( - id=DEPENDENT_NODE_ID, - host_id=DEPENDENT_NODE_ID, - blueprint_id=BLUEPRINT_ID, + name=DEPENDENT_NODE_NAME, + deployment_id=deployment.id, type='test_node_type', type_hierarchy=[], number_of_instances=1, @@ -100,21 +95,20 @@ def get_dependent_node(relationship=None): deploy_number_of_instances=1, properties={}, operations=dict((key, {}) for key in operations.NODE_OPERATIONS), - relationships=[relationship or get_relationship()], min_number_of_instances=1, max_number_of_instances=1, ) -def get_dependent_node_instance(relationship_instance=None, dependent_node=None): +def get_dependent_node_instance(dependent_node): return models.NodeInstance( - id=DEPENDENT_NODE_INSTANCE_ID, - host_id=DEPENDENT_NODE_INSTANCE_ID, - deployment_id=DEPLOYMENT_ID, + name=DEPENDENT_NODE_INSTANCE_NAME, runtime_properties={}, version=None, - relationship_instances=[relationship_instance or get_relationship_instance()], - node=dependent_node or get_dependency_node() + node_id=dependent_node.id, + deployment_id=dependent_node.deployment.id, + state='', + scaling_groups={} ) @@ -122,7 +116,7 @@ def get_blueprint(): now = datetime.now() return models.Blueprint( plan={}, - id=BLUEPRINT_ID, + name=BLUEPRINT_NAME, description=None, created_at=now, updated_at=now, @@ -130,25 +124,31 @@ def get_blueprint(): ) -def get_execution(): +def get_execution(deployment): return models.Execution( - id=EXECUTION_ID, + deployment_id=deployment.id, + blueprint_id=deployment.blueprint.id, status=models.Execution.STARTED, - deployment_id=DEPLOYMENT_ID, - workflow_id=WORKFLOW_ID, - blueprint_id=BLUEPRINT_ID, + workflow_name=WORKFLOW_NAME, started_at=datetime.utcnow(), parameters=None ) -def get_deployment(): +def get_deployment(blueprint): now = datetime.utcnow() return models.Deployment( - id=DEPLOYMENT_ID, - description=None, + name=DEPLOYMENT_NAME, + blueprint_id=blueprint.id, + description='', created_at=now, updated_at=now, - blueprint_id=BLUEPRINT_ID, - workflows={} + workflows={}, + inputs={}, + groups={}, + permalink='', + policy_triggers={}, + policy_types={}, + outputs={}, + scaling_groups={}, ) diff --git a/tests/orchestrator/context/test_operation.py b/tests/orchestrator/context/test_operation.py index 6b3e28d6..b5f52a3f 100644 --- a/tests/orchestrator/context/test_operation.py +++ b/tests/orchestrator/context/test_operation.py @@ -23,7 +23,7 @@ from aria.orchestrator.workflows import api from aria.orchestrator.workflows.executor import thread -from tests import mock +from tests import mock, storage from . import ( op_path, op_name, @@ -34,8 +34,10 @@ @pytest.fixture -def ctx(): - return mock.context.simple() +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) @pytest.fixture @@ -50,14 +52,13 @@ def executor(): def test_node_operation_task_execution(ctx, executor): operation_name = 'aria.interfaces.lifecycle.create' - node = mock.models.get_dependency_node() + node = ctx.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) node.operations[operation_name] = { 'operation': op_path(my_operation, module_path=__name__) } - node_instance = mock.models.get_dependency_node_instance(node) - ctx.model.node.store(node) - ctx.model.node_instance.store(node_instance) + ctx.model.node.update(node) + node_instance = ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) inputs = {'putput': True} @@ -90,26 +91,19 @@ def basic_workflow(graph, **_): def test_relationship_operation_task_execution(ctx, executor): operation_name = 'aria.interfaces.relationship_lifecycle.postconfigure' - - dependency_node = mock.models.get_dependency_node() - dependency_node_instance = mock.models.get_dependency_node_instance() - relationship = mock.models.get_relationship(target=dependency_node) + relationship = ctx.model.relationship.list()[0] relationship.source_operations[operation_name] = { 'operation': op_path(my_operation, module_path=__name__) } - relationship_instance = mock.models.get_relationship_instance( - target_instance=dependency_node_instance, - relationship=relationship) - dependent_node = mock.models.get_dependent_node() - dependent_node_instance = mock.models.get_dependent_node_instance( - relationship_instance=relationship_instance, - dependent_node=dependency_node) - ctx.model.node.store(dependency_node) - ctx.model.node_instance.store(dependency_node_instance) - ctx.model.relationship.store(relationship) - ctx.model.relationship_instance.store(relationship_instance) - ctx.model.node.store(dependent_node) - ctx.model.node_instance.store(dependent_node_instance) + ctx.model.relationship.update(relationship) + relationship_instance = ctx.model.relationship_instance.list()[0] + + dependency_node = ctx.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) + dependency_node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + dependent_node = ctx.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME) + dependent_node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENT_NODE_INSTANCE_NAME) inputs = {'putput': True} @@ -146,11 +140,49 @@ def basic_workflow(graph, **_): assert operation_context.source_node_instance == dependent_node_instance +def test_invalid_task_operation_id(ctx, executor): + """ + Checks that the right id is used. The task created with id == 1, thus running the task on + node_instance with id == 2. will check that indeed the node_instance uses the correct id. + :param ctx: + :param executor: + :return: + """ + operation_name = 'aria.interfaces.lifecycle.create' + other_node_instance, node_instance = ctx.model.node_instance.list() + assert other_node_instance.id == 1 + assert node_instance.id == 2 + + node = node_instance.node + node.operations[operation_name] = { + 'operation': op_path(get_node_instance_id, module_path=__name__) + + } + ctx.model.node.update(node) + + @workflow + def basic_workflow(graph, **_): + graph.add_tasks( + api.task.OperationTask.node_instance(name=operation_name, instance=node_instance) + ) + + execute(workflow_func=basic_workflow, workflow_context=ctx, executor=executor) + + op_node_instance_id = global_test_holder[op_name(node_instance, operation_name)] + assert op_node_instance_id == node_instance.id + assert op_node_instance_id != other_node_instance.id + + @operation def my_operation(ctx, **_): global_test_holder[ctx.name] = ctx +@operation +def get_node_instance_id(ctx, **_): + global_test_holder[ctx.name] = ctx.node_instance.id + + @pytest.fixture(autouse=True) def cleanup(): global_test_holder.clear() diff --git a/tests/orchestrator/context/test_toolbelt.py b/tests/orchestrator/context/test_toolbelt.py index 547e62b7..da466961 100644 --- a/tests/orchestrator/context/test_toolbelt.py +++ b/tests/orchestrator/context/test_toolbelt.py @@ -21,7 +21,7 @@ from aria.orchestrator.workflows.executor import thread from aria.orchestrator.context.toolbelt import RelationshipToolBelt -from tests import mock +from tests import mock, storage from . import ( op_path, op_name, @@ -32,8 +32,10 @@ @pytest.fixture -def workflow_context(): - return mock.context.simple() +def workflow_context(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) @pytest.fixture @@ -45,63 +47,39 @@ def executor(): result.close() -def _create_simple_model_in_storage(workflow_context): - dependency_node = mock.models.get_dependency_node() - dependency_node_instance = mock.models.get_dependency_node_instance( - dependency_node=dependency_node) - relationship = mock.models.get_relationship(target=dependency_node) - relationship_instance = mock.models.get_relationship_instance( - target_instance=dependency_node_instance, relationship=relationship) - dependent_node = mock.models.get_dependent_node() - dependent_node_instance = mock.models.get_dependent_node_instance( - relationship_instance=relationship_instance, dependent_node=dependency_node) - workflow_context.model.node.store(dependency_node) - workflow_context.model.node_instance.store(dependency_node_instance) - workflow_context.model.relationship.store(relationship) - workflow_context.model.relationship_instance.store(relationship_instance) - workflow_context.model.node.store(dependent_node) - workflow_context.model.node_instance.store(dependent_node_instance) - return dependency_node, dependency_node_instance, \ - dependent_node, dependent_node_instance, \ - relationship, relationship_instance +def _get_elements(workflow_context): + dependency_node = workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) + dependency_node.host_id = dependency_node.id + workflow_context.model.node.update(dependency_node) + dependency_node_instance = workflow_context.model.node_instance.get_by_name( + mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + dependency_node_instance.host_id = dependency_node_instance.id + workflow_context.model.node_instance.update(dependency_node_instance) -def test_host_ip(workflow_context, executor): - operation_name = 'aria.interfaces.lifecycle.create' - dependency_node, dependency_node_instance, _, _, _, _ = \ - _create_simple_model_in_storage(workflow_context) - dependency_node.operations[operation_name] = { - 'operation': op_path(host_ip, module_path=__name__) - - } - workflow_context.model.node.store(dependency_node) - inputs = {'putput': True} - - @workflow - def basic_workflow(graph, **_): - graph.add_tasks( - api.task.OperationTask.node_instance( - instance=dependency_node_instance, - name=operation_name, - inputs=inputs - ) - ) + dependent_node = workflow_context.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME) + dependent_node.host_id = dependency_node.id + workflow_context.model.node.update(dependent_node) - execute(workflow_func=basic_workflow, workflow_context=workflow_context, executor=executor) + dependent_node_instance = workflow_context.model.node_instance.get_by_name( + mock.models.DEPENDENT_NODE_INSTANCE_NAME) + dependent_node_instance.host_id = dependent_node_instance.id + workflow_context.model.node_instance.update(dependent_node_instance) - assert global_test_holder.get('host_ip') == \ - dependency_node_instance.runtime_properties.get('ip') + relationship = workflow_context.model.relationship.list()[0] + relationship_instance = workflow_context.model.relationship_instance.list()[0] + return dependency_node, dependency_node_instance, dependent_node, dependent_node_instance, \ + relationship, relationship_instance -def test_dependent_node_instances(workflow_context, executor): +def test_host_ip(workflow_context, executor): operation_name = 'aria.interfaces.lifecycle.create' - dependency_node, dependency_node_instance, _, dependent_node_instance, _, _ = \ - _create_simple_model_in_storage(workflow_context) + dependency_node, dependency_node_instance, _, _, _, _ = _get_elements(workflow_context) dependency_node.operations[operation_name] = { - 'operation': op_path(dependent_nodes, module_path=__name__) + 'operation': op_path(host_ip, module_path=__name__) } - workflow_context.model.node.store(dependency_node) + workflow_context.model.node.put(dependency_node) inputs = {'putput': True} @workflow @@ -116,18 +94,18 @@ def basic_workflow(graph, **_): execute(workflow_func=basic_workflow, workflow_context=workflow_context, executor=executor) - assert list(global_test_holder.get('dependent_node_instances', [])) == \ - list([dependent_node_instance]) + assert global_test_holder.get('host_ip') == \ + dependency_node_instance.runtime_properties.get('ip') def test_relationship_tool_belt(workflow_context, executor): operation_name = 'aria.interfaces.relationship_lifecycle.postconfigure' _, _, _, _, relationship, relationship_instance = \ - _create_simple_model_in_storage(workflow_context) + _get_elements(workflow_context) relationship.source_operations[operation_name] = { 'operation': op_path(relationship_operation, module_path=__name__) } - workflow_context.model.relationship.store(relationship) + workflow_context.model.relationship.put(relationship) inputs = {'putput': True} @@ -152,16 +130,12 @@ def test_wrong_model_toolbelt(): with pytest.raises(RuntimeError): context.toolbelt(None) + @operation(toolbelt=True) def host_ip(toolbelt, **_): global_test_holder['host_ip'] = toolbelt.host_ip -@operation(toolbelt=True) -def dependent_nodes(toolbelt, **_): - global_test_holder['dependent_node_instances'] = list(toolbelt.dependent_node_instances) - - @operation(toolbelt=True) def relationship_operation(ctx, toolbelt, **_): global_test_holder[ctx.name] = toolbelt diff --git a/tests/orchestrator/context/test_workflow.py b/tests/orchestrator/context/test_workflow.py index 258f0c59..496c1ffc 100644 --- a/tests/orchestrator/context/test_workflow.py +++ b/tests/orchestrator/context/test_workflow.py @@ -19,20 +19,19 @@ from aria import application_model_storage from aria.orchestrator import context - +from aria.storage.sql_mapi import SQLAlchemyModelAPI +from tests import storage as test_storage from tests.mock import models -from tests.storage import InMemoryModelDriver class TestWorkflowContext(object): def test_execution_creation_on_workflow_context_creation(self, storage): - self._create_ctx(storage) - execution = storage.execution.get(models.EXECUTION_ID) - assert execution.id == models.EXECUTION_ID - assert execution.deployment_id == models.DEPLOYMENT_ID - assert execution.workflow_id == models.WORKFLOW_ID - assert execution.blueprint_id == models.BLUEPRINT_ID + ctx = self._create_ctx(storage) + execution = storage.execution.get(ctx.execution.id) # pylint: disable=no-member + assert execution.deployment == storage.deployment.get_by_name(models.DEPLOYMENT_NAME) + assert execution.workflow_name == models.WORKFLOW_NAME + assert execution.blueprint == storage.blueprint.get_by_name(models.BLUEPRINT_NAME) assert execution.status == storage.execution.model_cls.PENDING assert execution.parameters == {} assert execution.created_at <= datetime.utcnow() @@ -43,13 +42,17 @@ def test_subsequent_workflow_context_creation_do_not_fail(self, storage): @staticmethod def _create_ctx(storage): + """ + + :param storage: + :return WorkflowContext: + """ return context.workflow.WorkflowContext( name='simple_context', model_storage=storage, resource_storage=None, - deployment_id=models.DEPLOYMENT_ID, - workflow_id=models.WORKFLOW_ID, - execution_id=models.EXECUTION_ID, + deployment_id=storage.deployment.get_by_name(models.DEPLOYMENT_NAME).id, + workflow_name=models.WORKFLOW_NAME, task_max_attempts=models.TASK_MAX_ATTEMPTS, task_retry_interval=models.TASK_RETRY_INTERVAL ) @@ -57,8 +60,10 @@ def _create_ctx(storage): @pytest.fixture(scope='function') def storage(): - result = application_model_storage(InMemoryModelDriver()) - result.setup() - result.blueprint.store(models.get_blueprint()) - result.deployment.store(models.get_deployment()) - return result + api_kwargs = test_storage.get_sqlite_api_kwargs() + workflow_storage = application_model_storage(SQLAlchemyModelAPI, api_kwargs=api_kwargs) + workflow_storage.blueprint.put(models.get_blueprint()) + blueprint = workflow_storage.blueprint.get_by_name(models.BLUEPRINT_NAME) + workflow_storage.deployment.put(models.get_deployment(blueprint)) + yield workflow_storage + test_storage.release_sqlite_storage(workflow_storage) diff --git a/tests/orchestrator/workflows/api/test_task.py b/tests/orchestrator/workflows/api/test_task.py index 85369021..1a903382 100644 --- a/tests/orchestrator/workflows/api/test_task.py +++ b/tests/orchestrator/workflows/api/test_task.py @@ -19,61 +19,38 @@ from aria.orchestrator import context from aria.orchestrator.workflows import api -from tests import mock +from tests import mock, storage -@pytest.fixture() +@pytest.fixture def ctx(): """ Create the following graph in storage: dependency_node <------ dependent_node :return: """ - simple_context = mock.context.simple() - dependency_node = mock.models.get_dependency_node() - dependency_node_instance = mock.models.get_dependency_node_instance( - dependency_node=dependency_node) - - relationship = mock.models.get_relationship(dependency_node) - relationship_instance = mock.models.get_relationship_instance( - relationship=relationship, - target_instance=dependency_node_instance - ) - - dependent_node = mock.models.get_dependent_node(relationship) - dependent_node_instance = mock.models.get_dependent_node_instance( - dependent_node=dependent_node, - relationship_instance=relationship_instance - ) - - simple_context.model.node.store(dependent_node) - simple_context.model.node.store(dependency_node) - simple_context.model.node_instance.store(dependent_node_instance) - simple_context.model.node_instance.store(dependency_node_instance) - simple_context.model.relationship.store(relationship) - simple_context.model.relationship_instance.store(relationship_instance) - simple_context.model.execution.store(mock.models.get_execution()) - simple_context.model.deployment.store(mock.models.get_deployment()) - - return simple_context + simple_context = mock.context.simple(storage.get_sqlite_api_kwargs()) + simple_context.model.execution.put(mock.models.get_execution(simple_context.deployment)) + yield simple_context + storage.release_sqlite_storage(simple_context.model) class TestOperationTask(object): - def test_node_operation_task_creation(self): - workflow_context = mock.context.simple() - + def test_node_operation_task_creation(self, ctx): operation_name = 'aria.interfaces.lifecycle.create' op_details = {'operation': True} - node = mock.models.get_dependency_node() + node = ctx.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME) node.operations[operation_name] = op_details - node_instance = mock.models.get_dependency_node_instance(dependency_node=node) + ctx.model.node.update(node) + node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENT_NODE_INSTANCE_NAME) inputs = {'inputs': True} max_attempts = 10 retry_interval = 10 ignore_failure = True - with context.workflow.current.push(workflow_context): + with context.workflow.current.push(ctx): api_task = api.task.OperationTask.node_instance( name=operation_name, instance=node_instance, @@ -90,19 +67,17 @@ def test_node_operation_task_creation(self): assert api_task.max_attempts == max_attempts assert api_task.ignore_failure == ignore_failure - def test_relationship_operation_task_creation(self): - workflow_context = mock.context.simple() - + def test_relationship_operation_task_creation(self, ctx): operation_name = 'aria.interfaces.relationship_lifecycle.preconfigure' op_details = {'operation': True} - relationship = mock.models.get_relationship() + relationship = ctx.model.relationship.list()[0] relationship.source_operations[operation_name] = op_details - relationship_instance = mock.models.get_relationship_instance(relationship=relationship) + relationship_instance = ctx.model.relationship_instance.list()[0] inputs = {'inputs': True} max_attempts = 10 retry_interval = 10 - with context.workflow.current.push(workflow_context): + with context.workflow.current.push(ctx): api_task = api.task.OperationTask.relationship_instance( name=operation_name, instance=relationship_instance, @@ -118,18 +93,19 @@ def test_relationship_operation_task_creation(self): assert api_task.retry_interval == retry_interval assert api_task.max_attempts == max_attempts - def test_operation_task_default_values(self): - workflow_context = mock.context.simple(task_ignore_failure=True) - with context.workflow.current.push(workflow_context): - model_task = api.task.OperationTask( + def test_operation_task_default_values(self, ctx): + dependency_node_instance = ctx.model.node_instance.get_by_name( + mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + with context.workflow.current.push(ctx): + task = api.task.OperationTask( name='stub', operation_mapping='', - actor=mock.models.get_dependency_node_instance()) + actor=dependency_node_instance) - assert model_task.inputs == {} - assert model_task.retry_interval == workflow_context._task_retry_interval - assert model_task.max_attempts == workflow_context._task_max_attempts - assert model_task.ignore_failure == workflow_context._task_ignore_failure + assert task.inputs == {} + assert task.retry_interval == ctx._task_retry_interval + assert task.max_attempts == ctx._task_max_attempts + assert task.ignore_failure == ctx._task_ignore_failure class TestWorkflowTask(object): diff --git a/tests/orchestrator/workflows/builtin/__init__.py b/tests/orchestrator/workflows/builtin/__init__.py index e100432a..26ba82f9 100644 --- a/tests/orchestrator/workflows/builtin/__init__.py +++ b/tests/orchestrator/workflows/builtin/__init__.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - from tests import mock + def assert_node_install_operations(operations, with_relationships=False): if with_relationships: all_operations = [ @@ -51,36 +51,3 @@ def assert_node_uninstall_operations(operations, with_relationships=False): else: for i, operation in enumerate(operations): assert operation.name.startswith(mock.operations.NODE_OPERATIONS_UNINSTALL[i]) - - -def ctx_with_basic_graph(): - """ - Create the following graph in storage: - dependency_node <------ dependent_node - :return: - """ - simple_context = mock.context.simple() - dependency_node = mock.models.get_dependency_node() - dependency_node_instance = mock.models.get_dependency_node_instance( - dependency_node=dependency_node) - - relationship = mock.models.get_relationship(dependency_node) - relationship_instance = mock.models.get_relationship_instance( - relationship=relationship, - target_instance=dependency_node_instance - ) - - dependent_node = mock.models.get_dependent_node(relationship) - dependent_node_instance = mock.models.get_dependent_node_instance( - dependent_node=dependent_node, - relationship_instance=relationship_instance - ) - - simple_context.model.node.store(dependent_node) - simple_context.model.node.store(dependency_node) - simple_context.model.node_instance.store(dependent_node_instance) - simple_context.model.node_instance.store(dependency_node_instance) - simple_context.model.relationship.store(relationship) - simple_context.model.relationship_instance.store(relationship_instance) - - return simple_context diff --git a/tests/orchestrator/workflows/builtin/test_execute_operation.py b/tests/orchestrator/workflows/builtin/test_execute_operation.py index 83e0d4d5..b7e56789 100644 --- a/tests/orchestrator/workflows/builtin/test_execute_operation.py +++ b/tests/orchestrator/workflows/builtin/test_execute_operation.py @@ -19,17 +19,20 @@ from aria.orchestrator.workflows.builtin.execute_operation import execute_operation from tests import mock -from . import ctx_with_basic_graph +from tests import storage @pytest.fixture -def ctx(): - return ctx_with_basic_graph() +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) def test_execute_operation(ctx): + node_instance = ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + operation_name = mock.operations.NODE_OPERATIONS_INSTALL[0] - node_instance_id = 'dependency_node_instance' execute_tasks = list( task.WorkflowTask( @@ -41,11 +44,13 @@ def test_execute_operation(ctx): run_by_dependency_order=False, type_names=[], node_ids=[], - node_instance_ids=[node_instance_id] + node_instance_ids=[node_instance.id] ).topological_order() ) assert len(execute_tasks) == 1 - assert execute_tasks[0].name == '{0}.{1}'.format(operation_name, node_instance_id) + assert execute_tasks[0].name == '{0}.{1}'.format(operation_name, node_instance.id) + + # TODO: add more scenarios diff --git a/tests/orchestrator/workflows/builtin/test_heal.py b/tests/orchestrator/workflows/builtin/test_heal.py index 940194b1..97121b9c 100644 --- a/tests/orchestrator/workflows/builtin/test_heal.py +++ b/tests/orchestrator/workflows/builtin/test_heal.py @@ -18,18 +18,25 @@ from aria.orchestrator.workflows.api import task from aria.orchestrator.workflows.builtin.heal import heal +from tests import mock, storage + from . import (assert_node_install_operations, - assert_node_uninstall_operations, - ctx_with_basic_graph) + assert_node_uninstall_operations) @pytest.fixture -def ctx(): - return ctx_with_basic_graph() +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) def test_heal_dependent_node(ctx): - heal_graph = task.WorkflowTask(heal, ctx=ctx, node_instance_id='dependent_node_instance') + dependent_node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENT_NODE_INSTANCE_NAME) + dependent_node_instance.host_id = dependent_node_instance.id + ctx.model.node_instance.update(dependent_node_instance) + heal_graph = task.WorkflowTask(heal, ctx=ctx, node_instance_id=dependent_node_instance.id) assert len(list(heal_graph.tasks)) == 2 uninstall_subgraph, install_subgraph = list(heal_graph.topological_order(reverse=True)) @@ -54,7 +61,11 @@ def test_heal_dependent_node(ctx): def test_heal_dependency_node(ctx): - heal_graph = task.WorkflowTask(heal, ctx=ctx, node_instance_id='dependency_node_instance') + dependency_node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + dependency_node_instance.host_id = dependency_node_instance.id + ctx.model.node_instance.update(dependency_node_instance) + heal_graph = task.WorkflowTask(heal, ctx=ctx, node_instance_id=dependency_node_instance.id) # both subgraphs should contain un\install for both the dependent and the dependency assert len(list(heal_graph.tasks)) == 2 uninstall_subgraph, install_subgraph = list(heal_graph.topological_order(reverse=True)) diff --git a/tests/orchestrator/workflows/builtin/test_install.py b/tests/orchestrator/workflows/builtin/test_install.py index 3b23c5ac..789a1616 100644 --- a/tests/orchestrator/workflows/builtin/test_install.py +++ b/tests/orchestrator/workflows/builtin/test_install.py @@ -12,22 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest -from aria.orchestrator.workflows.builtin.install import install from aria.orchestrator.workflows.api import task +from aria.orchestrator.workflows.builtin.install import install -from . import (assert_node_install_operations, - ctx_with_basic_graph) +from tests import mock +from tests import storage + +from . import assert_node_install_operations @pytest.fixture -def ctx(): - return ctx_with_basic_graph() +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) def test_install(ctx): + install_tasks = list(task.WorkflowTask(install, ctx=ctx).topological_order(True)) assert len(install_tasks) == 2 diff --git a/tests/orchestrator/workflows/builtin/test_uninstall.py b/tests/orchestrator/workflows/builtin/test_uninstall.py index 889e1d27..126c4cf3 100644 --- a/tests/orchestrator/workflows/builtin/test_uninstall.py +++ b/tests/orchestrator/workflows/builtin/test_uninstall.py @@ -18,16 +18,21 @@ from aria.orchestrator.workflows.api import task from aria.orchestrator.workflows.builtin.uninstall import uninstall -from . import (assert_node_uninstall_operations, - ctx_with_basic_graph) +from tests import mock +from tests import storage + +from . import assert_node_uninstall_operations @pytest.fixture -def ctx(): - return ctx_with_basic_graph() +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) def test_uninstall(ctx): + uninstall_tasks = list(task.WorkflowTask(uninstall, ctx=ctx).topological_order(True)) assert len(uninstall_tasks) == 2 diff --git a/tests/orchestrator/workflows/core/test_engine.py b/tests/orchestrator/workflows/core/test_engine.py index 1b00bf66..baded7f6 100644 --- a/tests/orchestrator/workflows/core/test_engine.py +++ b/tests/orchestrator/workflows/core/test_engine.py @@ -12,19 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import time import threading from datetime import datetime import pytest -import aria from aria.orchestrator import ( events, workflow, operation, - context ) from aria.storage import models from aria.orchestrator.workflows import ( @@ -34,9 +31,7 @@ from aria.orchestrator.workflows.core import engine from aria.orchestrator.workflows.executor import thread - -import tests.storage -from tests import mock +from tests import mock, storage global_test_holder = {} @@ -65,11 +60,11 @@ def _op(func, ctx, max_attempts=None, retry_interval=None, ignore_failure=None): - node_instance = ctx.model.node_instance.get('dependency_node_instance') + node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) node_instance.node.operations['aria.interfaces.lifecycle.create'] = { 'operation': '{name}.{func.__name__}'.format(name=__name__, func=func) } - ctx.model.node_instance.store(node_instance) return api.task.OperationTask.node_instance( instance=node_instance, name='aria.interfaces.lifecycle.create', @@ -79,14 +74,14 @@ def _op(func, ctx, ignore_failure=ignore_failure ) - @pytest.fixture(scope='function', autouse=True) + @pytest.fixture(autouse=True) def globals_cleanup(self): try: yield finally: global_test_holder.clear() - @pytest.fixture(scope='function', autouse=True) + @pytest.fixture(autouse=True) def signals_registration(self, ): def sent_task_handler(*args, **kwargs): calls = global_test_holder.setdefault('sent_task_signal_calls', 0) @@ -119,7 +114,7 @@ def cancel_workflow_handler(workflow_context, *args, **kwargs): events.on_cancelled_workflow_signal.disconnect(cancel_workflow_handler) events.sent_task_signal.disconnect(sent_task_handler) - @pytest.fixture(scope='function') + @pytest.fixture def executor(self): result = thread.ThreadExecutor() try: @@ -127,27 +122,13 @@ def executor(self): finally: result.close() - @pytest.fixture(scope='function') - def workflow_context(self): - model_storage = aria.application_model_storage(tests.storage.InMemoryModelDriver()) - model_storage.setup() - blueprint = mock.models.get_blueprint() - deployment = mock.models.get_deployment() - model_storage.blueprint.store(blueprint) - model_storage.deployment.store(deployment) - node = mock.models.get_dependency_node() - node_instance = mock.models.get_dependency_node_instance(node) - model_storage.node.store(node) - model_storage.node_instance.store(node_instance) - result = context.workflow.WorkflowContext( - name='test', - model_storage=model_storage, - resource_storage=None, - deployment_id=deployment.id, - workflow_id='name') - result.states = [] - result.exception = None - return result + @pytest.fixture + def workflow_context(self, tmpdir): + workflow_context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + workflow_context.states = [] + workflow_context.exception = None + yield workflow_context + storage.release_sqlite_storage(workflow_context.model) class TestEngine(BaseTest): @@ -245,7 +226,7 @@ def mock_workflow(ctx, graph): executor=executor) t = threading.Thread(target=eng.execute) t.start() - time.sleep(1) + time.sleep(10) eng.cancel_execution() t.join(timeout=30) assert workflow_context.states == ['start', 'cancel'] diff --git a/tests/orchestrator/workflows/core/test_task.py b/tests/orchestrator/workflows/core/test_task.py index 6a4c8ac6..c5725017 100644 --- a/tests/orchestrator/workflows/core/test_task.py +++ b/tests/orchestrator/workflows/core/test_task.py @@ -26,26 +26,14 @@ exceptions, ) -from tests import mock +from tests import mock, storage @pytest.fixture -def ctx(): - simple_context = mock.context.simple() - - blueprint = mock.models.get_blueprint() - deployment = mock.models.get_deployment() - node = mock.models.get_dependency_node() - node_instance = mock.models.get_dependency_node_instance(node) - execution = mock.models.get_execution() - - simple_context.model.blueprint.store(blueprint) - simple_context.model.deployment.store(deployment) - simple_context.model.node.store(node) - simple_context.model.node_instance.store(node_instance) - simple_context.model.execution.store(execution) - - return simple_context +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) class TestOperationTask(object): @@ -62,9 +50,10 @@ def _create_operation_task(self, ctx, node_instance): return api_task, core_task def test_operation_task_creation(self, ctx): - node_instance = ctx.model.node_instance.get(mock.models.DEPENDENCY_NODE_INSTANCE_ID) + node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) api_task, core_task = self._create_operation_task(ctx, node_instance) - storage_task = ctx.model.task.get(core_task.id) + storage_task = ctx.model.task.get_by_name(core_task.name) assert core_task.model_task == storage_task assert core_task.name == api_task.name @@ -73,7 +62,8 @@ def test_operation_task_creation(self, ctx): assert core_task.inputs == api_task.inputs == storage_task.inputs def test_operation_task_edit_locked_attribute(self, ctx): - node_instance = ctx.model.node_instance.get(mock.models.DEPENDENCY_NODE_INSTANCE_ID) + node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) _, core_task = self._create_operation_task(ctx, node_instance) now = datetime.utcnow() @@ -89,7 +79,8 @@ def test_operation_task_edit_locked_attribute(self, ctx): core_task.due_at = now def test_operation_task_edit_attributes(self, ctx): - node_instance = ctx.model.node_instance.get(mock.models.DEPENDENCY_NODE_INSTANCE_ID) + node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) _, core_task = self._create_operation_task(ctx, node_instance) future_time = datetime.utcnow() + timedelta(seconds=3) @@ -99,7 +90,7 @@ def test_operation_task_edit_attributes(self, ctx): core_task.started_at = future_time core_task.ended_at = future_time core_task.retry_count = 2 - core_task.eta = future_time + core_task.due_at = future_time assert core_task.status != core_task.STARTED assert core_task.started_at != future_time assert core_task.ended_at != future_time @@ -110,4 +101,4 @@ def test_operation_task_edit_attributes(self, ctx): assert core_task.started_at == future_time assert core_task.ended_at == future_time assert core_task.retry_count == 2 - assert core_task.eta == future_time + assert core_task.due_at == future_time diff --git a/tests/orchestrator/workflows/core/test_task_graph_into_exececution_graph.py b/tests/orchestrator/workflows/core/test_task_graph_into_exececution_graph.py index a179e491..18540f44 100644 --- a/tests/orchestrator/workflows/core/test_task_graph_into_exececution_graph.py +++ b/tests/orchestrator/workflows/core/test_task_graph_into_exececution_graph.py @@ -19,20 +19,14 @@ from aria.orchestrator.workflows import api, core from tests import mock +from tests import storage def test_task_graph_into_execution_graph(): operation_name = 'aria.interfaces.lifecycle.create' - task_context = mock.context.simple() - node = mock.models.get_dependency_node() - node_instance = mock.models.get_dependency_node_instance() - deployment = mock.models.get_deployment() - execution = mock.models.get_execution() - task_context.model.node.store(node) - task_context.model.node_instance.store(node_instance) - task_context.model.deployment.store(deployment) - task_context.model.execution.store(execution) - + task_context = mock.context.simple(storage.get_sqlite_api_kwargs()) + node_instance = \ + task_context.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) def sub_workflow(name, **_): return api.task_graph.TaskGraph(name) @@ -91,6 +85,7 @@ def sub_workflow(name, **_): simple_after_task) assert isinstance(_get_task_by_name(execution_tasks[6], execution_graph), core.task.EndWorkflowTask) + storage.release_sqlite_storage(task_context.model) def _assert_execution_is_api_task(execution_task, api_task): diff --git a/tests/requirements.txt b/tests/requirements.txt index cda295a7..0e4740fa 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -15,4 +15,4 @@ mock==1.0.1 pylint==1.6.4 pytest==3.0.2 pytest-cov==2.3.1 -pytest-mock==1.2 +pytest-mock==1.2 \ No newline at end of file diff --git a/tests/storage/__init__.py b/tests/storage/__init__.py index 9bf48cc2..edff982a 100644 --- a/tests/storage/__init__.py +++ b/tests/storage/__init__.py @@ -12,42 +12,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +import platform from tempfile import mkdtemp from shutil import rmtree -from aria.storage import ModelDriver +from sqlalchemy import ( + create_engine, + orm) +from sqlalchemy.orm import scoped_session +from sqlalchemy.pool import StaticPool + +from aria.storage import structures -class InMemoryModelDriver(ModelDriver): - def __init__(self, **kwargs): - super(InMemoryModelDriver, self).__init__(**kwargs) - self.storage = {} +class TestFileSystem(object): - def create(self, name, *args, **kwargs): - self.storage[name] = {} + def setup_method(self): + self.path = mkdtemp('{0}'.format(self.__class__.__name__)) - def get(self, name, entry_id, **kwargs): - return self.storage[name][entry_id].copy() + def teardown_method(self): + rmtree(self.path, ignore_errors=True) - def store(self, name, entry_id, entry, **kwargs): - self.storage[name][entry_id] = entry - def delete(self, name, entry_id, **kwargs): - self.storage[name].pop(entry_id) +def get_sqlite_api_kwargs(base_dir=None, filename='db.sqlite'): + """ + Create sql params. works in in-memory and in filesystem mode. + If base_dir is passed, the mode will be filesystem mode. while the default mode is in-memory. + :param str base_dir: The base dir for the filesystem memory file. + :param str filename: the file name - defaults to 'db.sqlite'. + :return: + """ + if base_dir is not None: + uri = 'sqlite:///{platform_char}{path}'.format( + # Handles the windows behavior where there is not root, but drivers. + # Thus behaving as relative path. + platform_char='' if 'Windows' in platform.system() else '/', - def iter(self, name, **kwargs): - for item in self.storage[name].itervalues(): - yield item.copy() + path=os.path.join(base_dir, filename)) + engine_kwargs = {} + else: + uri = 'sqlite:///:memory:' + engine_kwargs = dict(connect_args={'check_same_thread': False}, + poolclass=StaticPool) - def update(self, name, entry_id, **kwargs): - self.storage[name][entry_id].update(**kwargs) + engine = create_engine(uri, **engine_kwargs) + session_factory = orm.sessionmaker(bind=engine) + session = scoped_session(session_factory=session_factory) if base_dir else session_factory() + structures.Model.metadata.create_all(engine) + return dict(engine=engine, session=session) -class TestFileSystem(object): - def setup_method(self): - self.path = mkdtemp('{0}'.format(self.__class__.__name__)) +def release_sqlite_storage(storage): + """ + Drops the tables and clears the session + :param storage: + :return: + """ + mapis = storage.registered.values() - def teardown_method(self): - rmtree(self.path, ignore_errors=True) + if mapis: + for session in set(mapi._session for mapi in mapis): + session.rollback() + session.close() + for engine in set(mapi._engine for mapi in mapis): + structures.Model.metadata.drop_all(engine) diff --git a/tests/storage/test_drivers.py b/tests/storage/test_drivers.py deleted file mode 100644 index dccbe98a..00000000 --- a/tests/storage/test_drivers.py +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import pytest - -from aria.storage.drivers import FileSystemModelDriver, Driver, ModelDriver, ResourceDriver -from aria.storage.exceptions import StorageError - -from . import InMemoryModelDriver, TestFileSystem - - -def test_base_storage_driver(): - driver = Driver() - driver.connect() - driver.disconnect() - driver.create('name') - with driver as connection: - assert driver is connection - with pytest.raises(StorageError): - with driver: - raise StorageError() - - -def test_model_base_driver(): - driver = ModelDriver() - with pytest.raises(NotImplementedError): - driver.get('name', 'id') - with pytest.raises(NotImplementedError): - driver.store('name', entry={}, entry_id=None) - with pytest.raises(NotImplementedError): - driver.update('name', 'id', update_field=1) - with pytest.raises(NotImplementedError): - driver.delete('name', 'id') - with pytest.raises(NotImplementedError): - driver.iter('name') - - -def test_resource_base_driver(): - driver = ResourceDriver() - with pytest.raises(NotImplementedError): - driver.download('name', 'id', destination='dest') - with pytest.raises(NotImplementedError): - driver.upload('name', 'id', source='') - with pytest.raises(NotImplementedError): - driver.data('name', 'id') - - -def test_custom_driver(): - entry_dict = { - 'id': 'entry_id', - 'entry_value': 'entry_value' - } - - with InMemoryModelDriver() as driver: - driver.create('entry') - assert driver.storage['entry'] == {} - - driver.store(name='entry', entry=entry_dict, entry_id=entry_dict['id']) - assert driver.get(name='entry', entry_id='entry_id') == entry_dict - - assert list(node for node in driver.iter('entry')) == [entry_dict] - - driver.update(name='entry', entry_id=entry_dict['id'], entry_value='new_value') - assert driver.get(name='entry', entry_id='entry_id') == entry_dict - - driver.delete(name='entry', entry_id='entry_id') - - with pytest.raises(KeyError): - driver.get(name='entry', entry_id='entry_id') - - -class TestFileSystemDriver(TestFileSystem): - - def setup_method(self): - super(TestFileSystemDriver, self).setup_method() - self.driver = FileSystemModelDriver(directory=self.path) - - def test_name(self): - assert repr(self.driver) == ( - 'FileSystemModelDriver(directory={self.path})'.format(self=self)) - - def test_create(self): - self.driver.create(name='node') - assert os.path.exists(os.path.join(self.path, 'node')) - - def test_store(self): - self.test_create() - self.driver.store(name='node', entry_id='test_id', entry={'test': 'test'}) - assert os.path.exists(os.path.join(self.path, 'node', 'test_id')) - - def test_update(self): - self.test_store() - self.driver.update(name='node', entry_id='test_id', test='updated_test') - entry = self.driver.get(name='node', entry_id='test_id') - assert entry == {'test': 'updated_test'} - - def test_get(self): - self.test_store() - entry = self.driver.get(name='node', entry_id='test_id') - assert entry == {'test': 'test'} - - def test_delete(self): - self.test_store() - self.driver.delete(name='node', entry_id='test_id') - assert not os.path.exists(os.path.join(self.path, 'node', 'test_id')) - - def test_iter(self): - self.test_create() - entries = [ - {'test': 'test0'}, - {'test': 'test1'}, - {'test': 'test2'}, - {'test': 'test3'}, - {'test': 'test4'}, - ] - for entry_id, entry in enumerate(entries): - self.driver.store('node', str(entry_id), entry) - - for entry in self.driver.iter('node'): - entries.pop(entries.index(entry)) - - assert not entries diff --git a/tests/storage/test_field.py b/tests/storage/test_field.py deleted file mode 100644 index cab218fc..00000000 --- a/tests/storage/test_field.py +++ /dev/null @@ -1,124 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from aria.storage.structures import ( - Field, - IterField, - PointerField, - IterPointerField, - Model, -) - - -def model_factory(): - class TestModel(Model): - id = Field(default='test_id') - return TestModel() - - -def test_base_field(): - field = Field() - assert vars(field) == vars(Field(type=None, choices=(), default=Field.NO_DEFAULT)) - - -def test_type_check(): - field = Field(type=int) - assert vars(field) == vars(Field(type=int, choices=(), default=Field.NO_DEFAULT)) - with pytest.raises(TypeError): - field.validate_instance('field', 'any_value', int) - field.validate_instance('field', 1, int) - - -def test_field_choices(): - field = Field(choices=[1, 2]) - assert vars(field) == vars(Field(type=None, choices=[1, 2], default=Field.NO_DEFAULT)) - field.validate_in_choice('field', 1, field.choices) - - with pytest.raises(TypeError): - field.validate_in_choice('field', 'value', field.choices) - - -def test_field_without_default(): - class Test(object): - field = Field() - test = Test() - with pytest.raises(AttributeError, message="'Test' object has no attribute 'field'"): - assert test.field - - -def test_field_default_func(): - def true_func(): - return True - - field = Field(default=true_func) - assert vars(field) == vars(Field(type=None, choices=(), default=true_func)) - assert field.default - - -def test_field_default(): - field = Field(default='value') - assert vars(field) == vars(Field(type=None, choices=(), default='value')) - - -def test_iterable_field(): - iter_field = IterField(type=int) - assert vars(iter_field) == vars(Field(type=int, default=Field.NO_DEFAULT)) - iter_field.validate_value('iter_field', [1, 2]) - with pytest.raises(TypeError): - iter_field.validate_value('iter_field', ['a', 1]) - - -def test_pointer_field(): - test_model = model_factory() - - pointer_field = PointerField(type=Model) - assert vars(pointer_field) == \ - vars(PointerField(type=Model, choices=(), default=Field.NO_DEFAULT)) - with pytest.raises(AssertionError): - PointerField(type=list) - pointer_field.validate_value('pointer_field', test_model, None) - with pytest.raises(TypeError): - pointer_field.validate_value('pointer_field', int, None) - - -def test_iterable_pointer_field(): - test_model = model_factory() - iter_pointer_field = IterPointerField(type=Model) - assert vars(iter_pointer_field) == \ - vars(IterPointerField(type=Model, default=Field.NO_DEFAULT)) - with pytest.raises(AssertionError): - IterPointerField(type=list) - - iter_pointer_field.validate_value('iter_pointer_field', [test_model, test_model], None) - with pytest.raises(TypeError): - iter_pointer_field.validate_value('iter_pointer_field', [int, test_model], None) - - -def test_custom_field_validation(): - def validation_func(name, value, instance): - assert name == 'id' - assert value == 'value' - assert isinstance(instance, TestModel) - - class TestModel(Model): - id = Field(default='_', validation_func=validation_func) - - obj = TestModel() - obj.id = 'value' - - with pytest.raises(AssertionError): - obj.id = 'not_value' diff --git a/tests/storage/test_model_storage.py b/tests/storage/test_model_storage.py index 17e11aec..48cd02c0 100644 --- a/tests/storage/test_model_storage.py +++ b/tests/storage/test_model_storage.py @@ -16,78 +16,72 @@ import pytest from aria.storage import ( - Storage, ModelStorage, models, + exceptions, + sql_mapi, ) -from aria.storage import structures -from aria.storage.exceptions import StorageError -from aria.storage.structures import Model, Field, PointerField from aria import application_model_storage +from tests.storage import get_sqlite_api_kwargs, release_sqlite_storage -from . import InMemoryModelDriver +@pytest.fixture +def storage(): + base_storage = ModelStorage(sql_mapi.SQLAlchemyModelAPI, api_kwargs=get_sqlite_api_kwargs()) + yield base_storage + release_sqlite_storage(base_storage) -def test_storage_base(): - driver = InMemoryModelDriver() - storage = Storage(driver) - - assert storage.driver == driver +def test_storage_base(storage): with pytest.raises(AttributeError): storage.non_existent_attribute() -def test_model_storage(): - storage = ModelStorage(InMemoryModelDriver()) +def test_model_storage(storage): storage.register(models.ProviderContext) - storage.setup() - pc = models.ProviderContext(context={}, name='context_name', id='id1') - storage.provider_context.store(pc) + pc = models.ProviderContext(context={}, name='context_name') + storage.provider_context.put(pc) - assert storage.provider_context.get('id1') == pc + assert storage.provider_context.get_by_name('context_name') == pc assert [pc_from_storage for pc_from_storage in storage.provider_context.iter()] == [pc] assert [pc_from_storage for pc_from_storage in storage.provider_context] == [pc] - storage.provider_context.update('id1', context={'update_key': 0}) - assert storage.provider_context.get('id1').context == {'update_key': 0} + new_context = {'update_key': 0} + pc.context = new_context + storage.provider_context.update(pc) + assert storage.provider_context.get(pc.id).context == new_context - storage.provider_context.delete('id1') - with pytest.raises(StorageError): - storage.provider_context.get('id1') + storage.provider_context.delete(pc) + with pytest.raises(exceptions.StorageError): + storage.provider_context.get(pc.id) -def test_storage_driver(): - storage = ModelStorage(InMemoryModelDriver()) +def test_storage_driver(storage): storage.register(models.ProviderContext) - storage.setup() - pc = models.ProviderContext(context={}, name='context_name', id='id2') - storage.driver.store(name='provider_context', entry=pc.fields_dict, entry_id=pc.id) - assert storage.driver.get( - name='provider_context', - entry_id='id2', - model_cls=models.ProviderContext) == pc.fields_dict + pc = models.ProviderContext(context={}, name='context_name') + storage.registered['provider_context'].put(entry=pc) + + assert storage.registered['provider_context'].get_by_name('context_name') == pc - assert [i for i in storage.driver.iter(name='provider_context')] == [pc.fields_dict] + assert next(i for i in storage.registered['provider_context'].iter()) == pc assert [i for i in storage.provider_context] == [pc] - storage.provider_context.delete('id2') + storage.registered['provider_context'].delete(pc) - with pytest.raises(StorageError): - storage.provider_context.get('id2') + with pytest.raises(exceptions.StorageError): + storage.registered['provider_context'].get(pc.id) def test_application_storage_factory(): - driver = InMemoryModelDriver() - storage = application_model_storage(driver) + storage = application_model_storage(sql_mapi.SQLAlchemyModelAPI, + api_kwargs=get_sqlite_api_kwargs()) assert storage.node assert storage.node_instance assert storage.plugin assert storage.blueprint - assert storage.snapshot assert storage.deployment assert storage.deployment_update assert storage.deployment_update_step @@ -95,68 +89,4 @@ def test_application_storage_factory(): assert storage.execution assert storage.provider_context - reused_storage = application_model_storage(driver) - assert reused_storage == storage - - -def test_storage_pointers(): - class PointedModel(Model): - id = Field() - - class PointingModel(Model): - id = Field() - pointing_field = PointerField(type=PointedModel) - - storage = ModelStorage(InMemoryModelDriver(), model_classes=[PointingModel]) - storage.setup() - - assert storage.pointed_model - assert storage.pointing_model - - pointed_model = PointedModel(id='pointed_id') - - pointing_model = PointingModel(id='pointing_id', pointing_field=pointed_model) - storage.pointing_model.store(pointing_model) - - assert storage.pointed_model.get('pointed_id') == pointed_model - assert storage.pointing_model.get('pointing_id') == pointing_model - - storage.pointing_model.delete('pointing_id') - - with pytest.raises(StorageError): - assert storage.pointed_model.get('pointed_id') - assert storage.pointing_model.get('pointing_id') - - -def test_storage_iter_pointers(): - class PointedIterModel(models.Model): - id = structures.Field() - - class PointingIterModel(models.Model): - id = models.Field() - pointing_field = structures.IterPointerField(type=PointedIterModel) - - storage = ModelStorage(InMemoryModelDriver(), model_classes=[PointingIterModel]) - storage.setup() - - assert storage.pointed_iter_model - assert storage.pointing_iter_model - - pointed_iter_model1 = PointedIterModel(id='pointed_id1') - pointed_iter_model2 = PointedIterModel(id='pointed_id2') - - pointing_iter_model = PointingIterModel( - id='pointing_id', - pointing_field=[pointed_iter_model1, pointed_iter_model2]) - storage.pointing_iter_model.store(pointing_iter_model) - - assert storage.pointed_iter_model.get('pointed_id1') == pointed_iter_model1 - assert storage.pointed_iter_model.get('pointed_id2') == pointed_iter_model2 - assert storage.pointing_iter_model.get('pointing_id') == pointing_iter_model - - storage.pointing_iter_model.delete('pointing_id') - - with pytest.raises(StorageError): - assert storage.pointed_iter_model.get('pointed_id1') - assert storage.pointed_iter_model.get('pointed_id2') - assert storage.pointing_iter_model.get('pointing_id') + release_sqlite_storage(storage) diff --git a/tests/storage/test_models.py b/tests/storage/test_models.py index 7e289e6b..0ae5d1c5 100644 --- a/tests/storage/test_models.py +++ b/tests/storage/test_models.py @@ -12,353 +12,866 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import json +from contextlib import contextmanager from datetime import datetime - import pytest -from aria.storage import Model, Field -from aria.storage.exceptions import StorageError +from aria import application_model_storage +from aria.storage import exceptions +from aria.storage import sql_mapi from aria.storage.models import ( DeploymentUpdateStep, - Relationship, - RelationshipInstance, - Node, - NodeInstance, Blueprint, Execution, - Task + Task, + ProviderContext, + Plugin, + Deployment, + Node, + NodeInstance, + Relationship, + RelationshipInstance, + DeploymentUpdate, + DeploymentModification, ) -from tests.mock import models -# TODO: add tests per model +from tests import mock +from tests.storage import get_sqlite_api_kwargs, release_sqlite_storage -def test_base_model_without_fields(): - with pytest.raises(StorageError, message="Id field has to be in model fields"): - Model() +@contextmanager +def sql_storage(storage_func): + storage = None + try: + storage = storage_func() + yield storage + finally: + if storage: + release_sqlite_storage(storage) -def test_base_model_members(): - _test_field = Field() - class TestModel1(Model): - test_field = _test_field - id = Field(default='test_id') +def _empty_storage(): + return application_model_storage(sql_mapi.SQLAlchemyModelAPI, + api_kwargs=get_sqlite_api_kwargs()) - assert _test_field is TestModel1.test_field - test_model = TestModel1(test_field='test_field_value', id='test_id') +def _blueprint_storage(): + storage = _empty_storage() + blueprint = mock.models.get_blueprint() + storage.blueprint.put(blueprint) + return storage - assert repr(test_model) == "TestModel1(fields=['id', 'test_field'])" - expected = {'test_field': 'test_field_value', 'id': 'test_id'} - assert json.loads(test_model.json) == expected - assert test_model.fields_dict == expected - with pytest.raises(StorageError): - TestModel1() +def _deployment_storage(): + storage = _blueprint_storage() + deployment = mock.models.get_deployment(storage.blueprint.list()[0]) + storage.deployment.put(deployment) + return storage - with pytest.raises(StorageError): - TestModel1(test_field='test_field_value', id='test_id', unsupported_field='value') - class TestModel2(Model): - test_field = Field() - id = Field() +def _deployment_update_storage(): + storage = _deployment_storage() + deployment_update = DeploymentUpdate( + deployment_id=storage.deployment.list()[0].id, + created_at=now, + deployment_plan={}, + ) + storage.deployment_update.put(deployment_update) + return storage - with pytest.raises(StorageError): - TestModel2() +def _node_storage(): + storage = _deployment_storage() + node = mock.models.get_dependency_node(storage.deployment.list()[0]) + storage.node.put(node) + return storage -def test_blueprint_model(): - Blueprint( - plan={}, - id='id', - description='description', - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - main_file_name='/path', - ) - with pytest.raises(TypeError): - Blueprint( - plan=None, - id='id', - description='description', - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - main_file_name='/path', - ) - with pytest.raises(TypeError): - Blueprint( - plan={}, - id=999, - description='description', - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - main_file_name='/path', - ) - with pytest.raises(TypeError): - Blueprint( - plan={}, - id='id', - description=999, - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - main_file_name='/path', - ) - with pytest.raises(TypeError): - Blueprint( - plan={}, - id='id', - description='description', - created_at='error', - updated_at=datetime.utcnow(), - main_file_name='/path', - ) - with pytest.raises(TypeError): - Blueprint( - plan={}, - id='id', - description='description', - created_at=datetime.utcnow(), - updated_at=None, - main_file_name='/path', - ) - with pytest.raises(TypeError): - Blueprint( - plan={}, - id='id', - description='description', - created_at=datetime.utcnow(), - updated_at=None, - main_file_name=88, - ) - Blueprint( - plan={}, - description='description', - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - main_file_name='/path', - ) +def _nodes_storage(): + storage = _deployment_storage() + dependent_node = mock.models.get_dependent_node(storage.deployment.list()[0]) + dependency_node = mock.models.get_dependency_node(storage.deployment.list()[0]) + storage.node.put(dependent_node) + storage.node.put(dependency_node) + return storage -def test_deployment_update_step_model(): - add_node = DeploymentUpdateStep( - id='add_step', - action='add', - entity_type='node', - entity_id='node_id') - - modify_node = DeploymentUpdateStep( - id='modify_step', - action='modify', - entity_type='node', - entity_id='node_id') - - remove_node = DeploymentUpdateStep( - id='remove_step', - action='remove', - entity_type='node', - entity_id='node_id') - - for step in (add_node, modify_node, remove_node): - assert hash((step.id, step.entity_id)) == hash(step) - - assert remove_node < modify_node < add_node - assert not remove_node > modify_node > add_node - - add_rel = DeploymentUpdateStep( - id='add_step', - action='add', - entity_type='relationship', - entity_id='relationship_id') - - # modify_rel = DeploymentUpdateStep( - # id='modify_step', - # action='modify', - # entity_type='relationship', - # entity_id='relationship_id') - - remove_rel = DeploymentUpdateStep( - id='remove_step', - action='remove', - entity_type='relationship', - entity_id='relationship_id') - - assert remove_rel < remove_node < add_node < add_rel - assert not add_node < None - # TODO fix logic here so that pylint is happy - # assert not modify_node < modify_rel and not modify_rel < modify_node - - -def _relationship(id=''): - return Relationship( - id='rel{0}'.format(id), - target_id='target{0}'.format(id), - source_id='source{0}'.format(id), - source_interfaces={}, - source_operations={}, - target_interfaces={}, - target_operations={}, - type='type{0}'.format(id), - type_hierarchy=[], - properties={}) - - -def test_relationships(): - relationships = [_relationship(index) for index in xrange(3)] - - node = Node( - blueprint_id='blueprint_id', - type='type', - type_hierarchy=None, - number_of_instances=1, - planned_number_of_instances=1, - deploy_number_of_instances=1, - properties={}, - operations={}, - relationships=relationships, - min_number_of_instances=1, - max_number_of_instances=1) - - for index in xrange(3): - assert relationships[index] is \ - next(node.relationships_by_target('target{0}'.format(index))) - - relationship = _relationship() - - node = Node( - blueprint_id='blueprint_id', - type='type', - type_hierarchy=None, - number_of_instances=1, - planned_number_of_instances=1, - deploy_number_of_instances=1, - properties={}, - operations={}, - relationships=[relationship, relationship, relationship], - min_number_of_instances=1, - max_number_of_instances=1) - - for node_relationship in node.relationships_by_target('target'): - assert relationship is node_relationship - - -def test_relationship_instance(): - relationship = _relationship() - relationship_instances = [RelationshipInstance( - id='rel{0}'.format(index), - target_id='target_{0}'.format(index % 2), - source_id='source_{0}'.format(index % 2), - source_name='', - target_name='', - relationship=relationship, - type='type{0}'.format(index)) for index in xrange(3)] - - node_instance = NodeInstance( - deployment_id='deployment_id', - runtime_properties={}, - version='1', - relationship_instances=relationship_instances, - node=Node( - blueprint_id='blueprint_id', - type='type', - type_hierarchy=None, - number_of_instances=1, - planned_number_of_instances=1, - deploy_number_of_instances=1, - properties={}, - operations={}, - relationships=[], - min_number_of_instances=1, - max_number_of_instances=1), - scaling_groups=() - ) - from itertools import chain +def _node_instances_storage(): + storage = _nodes_storage() + dependent_node = storage.node.get_by_name(mock.models.DEPENDENT_NODE_NAME) + dependency_node = storage.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) + dependency_node_instance = mock.models.get_dependency_node_instance(dependency_node) + dependent_node_instance = mock.models.get_dependent_node_instance(dependent_node) + storage.node_instance.put(dependency_node_instance) + storage.node_instance.put(dependent_node_instance) + return storage - assert set(relationship_instances) == set(chain( - node_instance.relationships_by_target('target_0'), - node_instance.relationships_by_target('target_1'))) +def _execution_storage(): + storage = _deployment_storage() + execution = mock.models.get_execution(storage.deployment.list()[0]) + storage.execution.put(execution) + return storage -def test_execution_status_transition(): - def create_execution(status): - return Execution( - id='e_id', - deployment_id='d_id', - workflow_id='w_id', - blueprint_id='b_id', - status=status, - parameters={} - ) - valid_transitions = { - Execution.PENDING: [Execution.STARTED, - Execution.CANCELLED, - Execution.PENDING], - Execution.STARTED: [Execution.FAILED, - Execution.TERMINATED, - Execution.CANCELLED, - Execution.CANCELLING, - Execution.STARTED], - Execution.CANCELLING: [Execution.FAILED, - Execution.TERMINATED, - Execution.CANCELLED, - Execution.CANCELLING], - Execution.FAILED: [Execution.FAILED], - Execution.TERMINATED: [Execution.TERMINATED], - Execution.CANCELLED: [Execution.CANCELLED] - } - - invalid_transitions = { - Execution.PENDING: [Execution.FAILED, - Execution.TERMINATED, - Execution.CANCELLING], - Execution.STARTED: [Execution.PENDING], - Execution.CANCELLING: [Execution.PENDING, - Execution.STARTED], - Execution.FAILED: [Execution.PENDING, - Execution.STARTED, - Execution.TERMINATED, - Execution.CANCELLED, - Execution.CANCELLING], - Execution.TERMINATED: [Execution.PENDING, +@pytest.fixture +def empty_storage(): + with sql_storage(_empty_storage) as storage: + yield storage + + +@pytest.fixture +def blueprint_storage(): + with sql_storage(_blueprint_storage) as storage: + yield storage + + +@pytest.fixture +def deployment_storage(): + with sql_storage(_deployment_storage) as storage: + yield storage + + +@pytest.fixture +def deployment_update_storage(): + with sql_storage(_deployment_update_storage) as storage: + yield storage + + +@pytest.fixture +def node_storage(): + with sql_storage(_node_storage) as storage: + yield storage + + +@pytest.fixture +def nodes_storage(): + with sql_storage(_nodes_storage) as storage: + yield storage + + +@pytest.fixture +def node_instances_storage(): + with sql_storage(_node_instances_storage) as storage: + yield storage + + +@pytest.fixture +def execution_storage(): + with sql_storage(_execution_storage) as storage: + yield storage + + +m_cls = type('MockClass') +now = datetime.utcnow() + + +def _test_model(is_valid, storage, model_name, model_cls, model_kwargs): + if is_valid: + model = model_cls(**model_kwargs) + getattr(storage, model_name).put(model) + return model + else: + with pytest.raises(exceptions.StorageError): + getattr(storage, model_name).put(model_cls(**model_kwargs)) + + +class TestBlueprint(object): + + @pytest.mark.parametrize( + 'is_valid, plan, description, created_at, updated_at, main_file_name', + [ + (False, None, 'description', now, now, '/path'), + (False, {}, {}, now, now, '/path'), + (False, {}, 'description', 'error', now, '/path'), + (False, {}, 'description', now, 'error', '/path'), + (False, {}, 'description', now, now, {}), + (True, {}, 'description', now, now, '/path'), + ] + ) + def test_blueprint_model_creation(self, empty_storage, is_valid, plan, description, created_at, + updated_at, main_file_name): + if not is_valid: + with pytest.raises(exceptions.StorageError): + empty_storage.blueprint.put(Blueprint(plan=plan, description=description, + created_at=created_at, updated_at=updated_at, + main_file_name=main_file_name)) + else: + empty_storage.blueprint.put(Blueprint(plan=plan, description=description, + created_at=created_at, updated_at=updated_at, + main_file_name=main_file_name)) + + +class TestDeployment(object): + + @pytest.mark.parametrize( + 'is_valid, name, created_at, description, inputs, groups, permalink, policy_triggers, ' + 'policy_types, outputs, scaling_groups, updated_at, workflows', + [ + (False, m_cls, now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (False, 'name', m_cls, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (False, 'name', now, m_cls, {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (False, 'name', now, 'desc', m_cls, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (False, 'name', now, 'desc', {}, m_cls, 'perlnk', {}, {}, {}, {}, now, {}), + (False, 'name', now, 'desc', {}, {}, m_cls, {}, {}, {}, {}, now, {}), + (False, 'name', now, 'desc', {}, {}, 'perlnk', m_cls, {}, {}, {}, now, {}), + (False, 'name', now, 'desc', {}, {}, 'perlnk', {}, m_cls, {}, {}, now, {}), + (False, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, m_cls, {}, now, {}), + (False, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, m_cls, now, {}), + (False, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, m_cls, {}), + (False, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, m_cls), + + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (True, None, now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (True, 'name', now, None, {}, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (True, 'name', now, 'desc', None, {}, 'perlnk', {}, {}, {}, {}, now, {}), + (True, 'name', now, 'desc', {}, None, 'perlnk', {}, {}, {}, {}, now, {}), + (True, 'name', now, 'desc', {}, {}, None, {}, {}, {}, {}, now, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', None, {}, {}, {}, now, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, None, {}, {}, now, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, None, {}, now, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, None, now, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, None, {}), + (True, 'name', now, 'desc', {}, {}, 'perlnk', {}, {}, {}, {}, now, None), + ] + ) + def test_deployment_model_creation(self, deployment_storage, is_valid, name, created_at, + description, inputs, groups, permalink, policy_triggers, + policy_types, outputs, scaling_groups, updated_at, + workflows): + deployment = _test_model(is_valid=is_valid, + storage=deployment_storage, + model_name='deployment', + model_cls=Deployment, + model_kwargs=dict( + name=name, + blueprint_id=deployment_storage.blueprint.list()[0].id, + created_at=created_at, + description=description, + inputs=inputs, + groups=groups, + permalink=permalink, + policy_triggers=policy_triggers, + policy_types=policy_types, + outputs=outputs, + scaling_groups=scaling_groups, + updated_at=updated_at, + workflows=workflows + )) + if is_valid: + assert deployment.blueprint == deployment_storage.blueprint.list()[0] + + +class TestExecution(object): + + @pytest.mark.parametrize( + 'is_valid, created_at, started_at, ended_at, error, is_system_workflow, parameters, ' + 'status, workflow_name', + [ + (False, m_cls, now, now, 'error', False, {}, Execution.STARTED, 'wf_name'), + (False, now, m_cls, now, 'error', False, {}, Execution.STARTED, 'wf_name'), + (False, now, now, m_cls, 'error', False, {}, Execution.STARTED, 'wf_name'), + (False, now, now, now, m_cls, False, {}, Execution.STARTED, 'wf_name'), + (False, now, now, now, 'error', False, m_cls, Execution.STARTED, 'wf_name'), + (False, now, now, now, 'error', False, {}, m_cls, 'wf_name'), + (False, now, now, now, 'error', False, {}, Execution.STARTED, m_cls), + + (True, now, now, now, 'error', False, {}, Execution.STARTED, 'wf_name'), + (True, now, None, now, 'error', False, {}, Execution.STARTED, 'wf_name'), + (True, now, now, None, 'error', False, {}, Execution.STARTED, 'wf_name'), + (True, now, now, now, None, False, {}, Execution.STARTED, 'wf_name'), + (True, now, now, now, 'error', False, None, Execution.STARTED, 'wf_name'), + ] + ) + def test_execution_model_creation(self, deployment_storage, is_valid, created_at, started_at, + ended_at, error, is_system_workflow, parameters, status, + workflow_name): + execution = _test_model(is_valid=is_valid, + storage=deployment_storage, + model_name='execution', + model_cls=Execution, + model_kwargs=dict( + deployment_id=deployment_storage.deployment.list()[0].id, + blueprint_id=deployment_storage.blueprint.list()[0].id, + created_at=created_at, + started_at=started_at, + ended_at=ended_at, + error=error, + is_system_workflow=is_system_workflow, + parameters=parameters, + status=status, + workflow_name=workflow_name, + )) + if is_valid: + assert execution.deployment == deployment_storage.deployment.list()[0] + assert execution.blueprint == deployment_storage.blueprint.list()[0] + + def test_execution_status_transition(self): + def create_execution(status): + execution = Execution( + id='e_id', + workflow_name='w_name', + status=status, + parameters={}, + created_at=now, + ) + return execution + + valid_transitions = { + Execution.PENDING: [Execution.STARTED, + Execution.CANCELLED, + Execution.PENDING], + Execution.STARTED: [Execution.FAILED, + Execution.TERMINATED, + Execution.CANCELLED, + Execution.CANCELLING, + Execution.STARTED], + Execution.CANCELLING: [Execution.FAILED, + Execution.TERMINATED, + Execution.CANCELLED, + Execution.CANCELLING], + Execution.FAILED: [Execution.FAILED], + Execution.TERMINATED: [Execution.TERMINATED], + Execution.CANCELLED: [Execution.CANCELLED] + } + + invalid_transitions = { + Execution.PENDING: [Execution.FAILED, + Execution.TERMINATED, + Execution.CANCELLING], + Execution.STARTED: [Execution.PENDING], + Execution.CANCELLING: [Execution.PENDING, + Execution.STARTED], + Execution.FAILED: [Execution.PENDING, Execution.STARTED, - Execution.FAILED, + Execution.TERMINATED, Execution.CANCELLED, Execution.CANCELLING], - Execution.CANCELLED: [Execution.PENDING, - Execution.STARTED, - Execution.FAILED, - Execution.TERMINATED, - Execution.CANCELLING], - } - - for current_status, valid_transitioned_statues in valid_transitions.items(): - for transitioned_status in valid_transitioned_statues: - execution = create_execution(current_status) - execution.status = transitioned_status - - for current_status, invalid_transitioned_statues in invalid_transitions.items(): - for transitioned_status in invalid_transitioned_statues: - execution = create_execution(current_status) - with pytest.raises(ValueError): + Execution.TERMINATED: [Execution.PENDING, + Execution.STARTED, + Execution.FAILED, + Execution.CANCELLED, + Execution.CANCELLING], + Execution.CANCELLED: [Execution.PENDING, + Execution.STARTED, + Execution.FAILED, + Execution.TERMINATED, + Execution.CANCELLING], + } + + for current_status, valid_transitioned_statues in valid_transitions.items(): + for transitioned_status in valid_transitioned_statues: + execution = create_execution(current_status) execution.status = transitioned_status - -def test_task_max_attempts_validation(): - def create_task(max_attempts): - Task(execution_id='eid', - name='name', - operation_mapping='', - inputs={}, - actor=models.get_dependency_node_instance(), - max_attempts=max_attempts) - create_task(max_attempts=1) - create_task(max_attempts=2) - create_task(max_attempts=Task.INFINITE_RETRIES) - with pytest.raises(ValueError): - create_task(max_attempts=0) - with pytest.raises(ValueError): - create_task(max_attempts=-2) + for current_status, invalid_transitioned_statues in invalid_transitions.items(): + for transitioned_status in invalid_transitioned_statues: + execution = create_execution(current_status) + with pytest.raises(ValueError): + execution.status = transitioned_status + + +class TestDeploymentUpdate(object): + @pytest.mark.parametrize( + 'is_valid, created_at, deployment_plan, deployment_update_node_instances, ' + 'deployment_update_deployment, deployment_update_nodes, modified_entity_ids, state', + [ + (False, m_cls, {}, {}, {}, {}, {}, 'state'), + (False, now, m_cls, {}, {}, {}, {}, 'state'), + (False, now, {}, m_cls, {}, {}, {}, 'state'), + (False, now, {}, {}, m_cls, {}, {}, 'state'), + (False, now, {}, {}, {}, m_cls, {}, 'state'), + (False, now, {}, {}, {}, {}, m_cls, 'state'), + (False, now, {}, {}, {}, {}, {}, m_cls), + + (True, now, {}, {}, {}, {}, {}, 'state'), + (True, now, {}, None, {}, {}, {}, 'state'), + (True, now, {}, {}, None, {}, {}, 'state'), + (True, now, {}, {}, {}, None, {}, 'state'), + (True, now, {}, {}, {}, {}, None, 'state'), + (True, now, {}, {}, {}, {}, {}, None), + ] + ) + def test_deployment_update_model_creation(self, deployment_storage, is_valid, created_at, + deployment_plan, deployment_update_node_instances, + deployment_update_deployment, deployment_update_nodes, + modified_entity_ids, state): + deployment_update = _test_model( + is_valid=is_valid, + storage=deployment_storage, + model_name='deployment_update', + model_cls=DeploymentUpdate, + model_kwargs=dict( + deployment_id=deployment_storage.deployment.list()[0].id, + created_at=created_at, + deployment_plan=deployment_plan, + deployment_update_node_instances=deployment_update_node_instances, + deployment_update_deployment=deployment_update_deployment, + deployment_update_nodes=deployment_update_nodes, + modified_entity_ids=modified_entity_ids, + state=state, + )) + if is_valid: + assert deployment_update.deployment == deployment_storage.deployment.list()[0] + + +class TestDeploymentUpdateStep(object): + + @pytest.mark.parametrize( + 'is_valid, action, entity_id, entity_type', + [ + (False, m_cls, 'id', DeploymentUpdateStep.ENTITY_TYPES.NODE), + (False, DeploymentUpdateStep.ACTION_TYPES.ADD, m_cls, + DeploymentUpdateStep.ENTITY_TYPES.NODE), + (False, DeploymentUpdateStep.ACTION_TYPES.ADD, 'id', m_cls), + + (True, DeploymentUpdateStep.ACTION_TYPES.ADD, 'id', + DeploymentUpdateStep.ENTITY_TYPES.NODE) + ] + ) + def test_deployment_update_step_model_creation(self, deployment_update_storage, is_valid, + action, entity_id, entity_type): + deployment_update_step = _test_model( + is_valid=is_valid, + storage=deployment_update_storage, + model_name='deployment_update_step', + model_cls=DeploymentUpdateStep, + model_kwargs=dict( + deployment_update_id=deployment_update_storage.deployment_update.list()[0].id, + action=action, + entity_id=entity_id, + entity_type=entity_type + )) + if is_valid: + assert deployment_update_step.deployment_update == \ + deployment_update_storage.deployment_update.list()[0] + + def test_deployment_update_step_order(self): + add_node = DeploymentUpdateStep( + id='add_step', + action='add', + entity_type='node', + entity_id='node_id') + + modify_node = DeploymentUpdateStep( + id='modify_step', + action='modify', + entity_type='node', + entity_id='node_id') + + remove_node = DeploymentUpdateStep( + id='remove_step', + action='remove', + entity_type='node', + entity_id='node_id') + + for step in (add_node, modify_node, remove_node): + assert hash((step.id, step.entity_id)) == hash(step) + + assert remove_node < modify_node < add_node + assert not remove_node > modify_node > add_node + + add_rel = DeploymentUpdateStep( + id='add_step', + action='add', + entity_type='relationship', + entity_id='relationship_id') + + remove_rel = DeploymentUpdateStep( + id='remove_step', + action='remove', + entity_type='relationship', + entity_id='relationship_id') + + assert remove_rel < remove_node < add_node < add_rel + assert not add_node < None + + +class TestDeploymentModification(object): + @pytest.mark.parametrize( + 'is_valid, context, created_at, ended_at, modified_nodes, node_instances, status', + [ + (False, m_cls, now, now, {}, {}, DeploymentModification.STARTED), + (False, {}, m_cls, now, {}, {}, DeploymentModification.STARTED), + (False, {}, now, m_cls, {}, {}, DeploymentModification.STARTED), + (False, {}, now, now, m_cls, {}, DeploymentModification.STARTED), + (False, {}, now, now, {}, m_cls, DeploymentModification.STARTED), + (False, {}, now, now, {}, {}, m_cls), + + (True, {}, now, now, {}, {}, DeploymentModification.STARTED), + (True, {}, now, None, {}, {}, DeploymentModification.STARTED), + (True, {}, now, now, None, {}, DeploymentModification.STARTED), + (True, {}, now, now, {}, None, DeploymentModification.STARTED), + ] + ) + def test_deployment_modification_model_creation(self, deployment_storage, is_valid, context, + created_at, ended_at, modified_nodes, + node_instances, status): + deployment_modification = _test_model( + is_valid=is_valid, + storage=deployment_storage, + model_name='deployment_modification', + model_cls=DeploymentModification, + model_kwargs=dict( + deployment_id=deployment_storage.deployment.list()[0].id, + context=context, + created_at=created_at, + ended_at=ended_at, + modified_nodes=modified_nodes, + node_instances=node_instances, + status=status, + )) + if is_valid: + assert deployment_modification.deployment == deployment_storage.deployment.list()[0] + + +class TestNode(object): + @pytest.mark.parametrize( + 'is_valid, name, deploy_number_of_instances, max_number_of_instances, ' + 'min_number_of_instances, number_of_instances, planned_number_of_instances, plugins, ' + 'plugins_to_install, properties, operations, type, type_hierarchy', + [ + (False, m_cls, 1, 1, 1, 1, 1, {}, {}, {}, {}, 'type', []), + (False, 'name', m_cls, 1, 1, 1, 1, {}, {}, {}, {}, 'type', []), + (False, 'name', 1, m_cls, 1, 1, 1, {}, {}, {}, {}, 'type', []), + (False, 'name', 1, 1, m_cls, 1, 1, {}, {}, {}, {}, 'type', []), + (False, 'name', 1, 1, 1, m_cls, 1, {}, {}, {}, {}, 'type', []), + (False, 'name', 1, 1, 1, 1, m_cls, {}, {}, {}, {}, 'type', []), + (False, 'name', 1, 1, 1, 1, 1, m_cls, {}, {}, {}, 'type', []), + (False, 'name', 1, 1, 1, 1, 1, {}, m_cls, {}, {}, 'type', []), + (False, 'name', 1, 1, 1, 1, 1, {}, {}, m_cls, {}, 'type', []), + (False, 'name', 1, 1, 1, 1, 1, {}, {}, {}, m_cls, 'type', []), + (False, 'name', 1, 1, 1, 1, 1, {}, {}, {}, {}, m_cls, []), + (False, 'name', 1, 1, 1, 1, 1, {}, {}, {}, {}, 'type', m_cls), + + (True, 'name', 1, 1, 1, 1, 1, {}, {}, {}, {}, 'type', []), + (True, 'name', 1, 1, 1, 1, 1, None, {}, {}, {}, 'type', []), + (True, 'name', 1, 1, 1, 1, 1, {}, None, {}, {}, 'type', []), + (True, 'name', 1, 1, 1, 1, 1, {}, {}, None, {}, 'type', []), + (True, 'name', 1, 1, 1, 1, 1, {}, {}, {}, None, 'type', []), + (True, 'name', 1, 1, 1, 1, 1, {}, {}, {}, {}, 'type', []), + (True, 'name', 1, 1, 1, 1, 1, {}, {}, {}, {}, 'type', None), + ] + ) + def test_node_model_creation(self, deployment_storage, is_valid, name, + deploy_number_of_instances, max_number_of_instances, + min_number_of_instances, number_of_instances, + planned_number_of_instances, plugins, plugins_to_install, + properties, operations, type, type_hierarchy): + node = _test_model( + is_valid=is_valid, + storage=deployment_storage, + model_name='node', + model_cls=Node, + model_kwargs=dict( + name=name, + deploy_number_of_instances=deploy_number_of_instances, + max_number_of_instances=max_number_of_instances, + min_number_of_instances=min_number_of_instances, + number_of_instances=number_of_instances, + planned_number_of_instances=planned_number_of_instances, + plugins=plugins, + plugins_to_install=plugins_to_install, + properties=properties, + operations=operations, + type=type, + type_hierarchy=type_hierarchy, + deployment_id=deployment_storage.deployment.list()[0].id + )) + if is_valid: + assert node.deployment == deployment_storage.deployment.list()[0] + + +class TestRelationship(object): + @pytest.mark.parametrize( + 'is_valid, source_interfaces, source_operations, target_interfaces, target_operations, ' + 'type, type_hierarchy, properties', + [ + (False, m_cls, {}, {}, {}, 'type', [], {}), + (False, {}, m_cls, {}, {}, 'type', [], {}), + (False, {}, {}, m_cls, {}, 'type', [], {}), + (False, {}, {}, {}, m_cls, 'type', [], {}), + (False, {}, {}, {}, {}, m_cls, [], {}), + (False, {}, {}, {}, {}, 'type', m_cls, {}), + (False, {}, {}, {}, {}, 'type', [], m_cls), + + (True, {}, {}, {}, {}, 'type', [], {}), + (True, None, {}, {}, {}, 'type', [], {}), + (True, {}, {}, None, {}, 'type', [], {}), + (True, {}, {}, {}, {}, 'type', None, {}), + (True, {}, {}, {}, {}, 'type', [], None), + ] + ) + def test_relationship_model_ceration(self, nodes_storage, is_valid, source_interfaces, + source_operations, target_interfaces, target_operations, + type, type_hierarchy, properties): + relationship = _test_model( + is_valid=is_valid, + storage=nodes_storage, + model_name='relationship', + model_cls=Relationship, + model_kwargs=dict( + source_node_id=nodes_storage.node.list()[1].id, + target_node_id=nodes_storage.node.list()[0].id, + source_interfaces=source_interfaces, + source_operations=source_operations, + target_interfaces=target_interfaces, + target_operations=target_operations, + type=type, + type_hierarchy=type_hierarchy, + properties=properties, + )) + if is_valid: + assert relationship.source_node == nodes_storage.node.list()[1] + assert relationship.target_node == nodes_storage.node.list()[0] + + +class TestNodeInstance(object): + @pytest.mark.parametrize( + 'is_valid, name, runtime_properties, scaling_groups, state, version', + [ + (False, m_cls, {}, {}, 'state', 1), + (False, 'name', m_cls, {}, 'state', 1), + (False, 'name', {}, m_cls, 'state', 1), + (False, 'name', {}, {}, m_cls, 1), + (False, m_cls, {}, {}, 'state', m_cls), + + (True, 'name', {}, {}, 'state', 1), + (True, None, {}, {}, 'state', 1), + (True, 'name', None, {}, 'state', 1), + (True, 'name', {}, None, 'state', 1), + (True, 'name', {}, {}, 'state', None), + ] + ) + def test_node_instance_model_creation(self, node_storage, is_valid, name, runtime_properties, + scaling_groups, state, version): + node_instance = _test_model( + is_valid=is_valid, + storage=node_storage, + model_name='node_instance', + model_cls=NodeInstance, + model_kwargs=dict( + node_id=node_storage.node.list()[0].id, + deployment_id=node_storage.deployment.list()[0].id, + name=name, + runtime_properties=runtime_properties, + scaling_groups=scaling_groups, + state=state, + version=version, + )) + if is_valid: + assert node_instance.node == node_storage.node.list()[0] + assert node_instance.deployment == node_storage.deployment.list()[0] + + +class TestRelationshipInstance(object): + def test_relatiship_instance_model_creation(self, node_instances_storage): + relationship = mock.models.get_relationship( + source=node_instances_storage.node.get_by_name(mock.models.DEPENDENT_NODE_NAME), + target=node_instances_storage.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) + ) + node_instances_storage.relationship.put(relationship) + node_instances = node_instances_storage.node_instance + source_node_instance = node_instances.get_by_name(mock.models.DEPENDENT_NODE_INSTANCE_NAME) + target_node_instance = node_instances.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + + relationship_instance = _test_model( + is_valid=True, + storage=node_instances_storage, + model_name='relationship_instance', + model_cls=RelationshipInstance, + model_kwargs=dict( + relationship_id=relationship.id, + source_node_instance_id=source_node_instance.id, + target_node_instance_id=target_node_instance.id + )) + assert relationship_instance.relationship == relationship + assert relationship_instance.source_node_instance == source_node_instance + assert relationship_instance.target_node_instance == target_node_instance + + +class TestProviderContext(object): + @pytest.mark.parametrize( + 'is_valid, name, context', + [ + (False, None, {}), + (False, 'name', None), + (True, 'name', {}), + ] + ) + def test_provider_context_model_creation(self, empty_storage, is_valid, name, context): + _test_model(is_valid=is_valid, + storage=empty_storage, + model_name='provider_context', + model_cls=ProviderContext, + model_kwargs=dict(name=name, context=context) + ) + + +class TestPlugin(object): + @pytest.mark.parametrize( + 'is_valid, archive_name, distribution, distribution_release, ' + 'distribution_version, excluded_wheels, package_name, package_source, ' + 'package_version, supported_platform, supported_py_versions, uploaded_at, wheels', + [ + (False, m_cls, 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (False, 'arc_name', m_cls, 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (False, 'arc_name', 'dis_name', m_cls, 'dis_ver', {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', m_cls, {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', m_cls, 'pak_name', 'pak_src', + 'pak_ver', {}, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, m_cls, 'pak_src', 'pak_ver', + {}, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', m_cls, 'pak_ver', + {}, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', m_cls, + {}, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', m_cls, {}, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', {}, m_cls, now, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', {}, {}, m_cls, {}), + (False, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', {}, {}, now, m_cls), + + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', {}, {}, now, {}), + (True, 'arc_name', None, 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (True, 'arc_name', 'dis_name', None, 'dis_ver', {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', None, {}, 'pak_name', 'pak_src', 'pak_ver', + {}, {}, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', None, 'pak_name', 'pak_src', + 'pak_ver', {}, {}, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', None, 'pak_ver', + {}, {}, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', None, + {}, {}, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', None, {}, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', {}, None, now, {}), + (True, 'arc_name', 'dis_name', 'dis_rel', 'dis_ver', {}, 'pak_name', 'pak_src', + 'pak_ver', {}, {}, now, {}), + ] + ) + def test_plugin_model_creation(self, empty_storage, is_valid, archive_name, distribution, + distribution_release, distribution_version, excluded_wheels, + package_name, package_source, package_version, + supported_platform, supported_py_versions, uploaded_at, wheels): + _test_model(is_valid=is_valid, + storage=empty_storage, + model_name='plugin', + model_cls=Plugin, + model_kwargs=dict( + archive_name=archive_name, + distribution=distribution, + distribution_release=distribution_release, + distribution_version=distribution_version, + excluded_wheels=excluded_wheels, + package_name=package_name, + package_source=package_source, + package_version=package_version, + supported_platform=supported_platform, + supported_py_versions=supported_py_versions, + uploaded_at=uploaded_at, + wheels=wheels, + )) + + +class TestTask(object): + + @pytest.mark.parametrize( + 'is_valid, status, due_at, started_at, ended_at, max_attempts, retry_count, ' + 'retry_interval, ignore_failure, name, operation_mapping, inputs', + [ + (False, m_cls, now, now, now, 1, 1, 1, True, 'name', 'map', {}), + (False, Task.STARTED, m_cls, now, now, 1, 1, 1, True, 'name', 'map', {}), + (False, Task.STARTED, now, m_cls, now, 1, 1, 1, True, 'name', 'map', {}), + (False, Task.STARTED, now, now, m_cls, 1, 1, 1, True, 'name', 'map', {}), + (False, Task.STARTED, now, now, now, m_cls, 1, 1, True, 'name', 'map', {}), + (False, Task.STARTED, now, now, now, 1, m_cls, 1, True, 'name', 'map', {}), + (False, Task.STARTED, now, now, now, 1, 1, m_cls, True, 'name', 'map', {}), + (False, Task.STARTED, now, now, now, 1, 1, 1, True, m_cls, 'map', {}), + (False, Task.STARTED, now, now, now, 1, 1, 1, True, 'name', m_cls, {}), + (False, Task.STARTED, now, now, now, 1, 1, 1, True, 'name', 'map', m_cls), + + (True, Task.STARTED, now, now, now, 1, 1, 1, True, 'name', 'map', {}), + (True, Task.STARTED, None, now, now, 1, 1, 1, True, 'name', 'map', {}), + (True, Task.STARTED, now, None, now, 1, 1, 1, True, 'name', 'map', {}), + (True, Task.STARTED, now, now, None, 1, 1, 1, True, 'name', 'map', {}), + (True, Task.STARTED, now, now, now, 1, None, 1, True, 'name', 'map', {}), + (True, Task.STARTED, now, now, now, 1, 1, None, True, 'name', 'map', {}), + (True, Task.STARTED, now, now, now, 1, 1, 1, None, 'name', 'map', {}), + (True, Task.STARTED, now, now, now, 1, 1, 1, True, None, 'map', {}), + (True, Task.STARTED, now, now, now, 1, 1, 1, True, 'name', None, {}), + (True, Task.STARTED, now, now, now, 1, 1, 1, True, 'name', 'map', None), + ] + ) + def test_task_model_creation(self, execution_storage, is_valid, status, due_at, started_at, + ended_at, max_attempts, retry_count, retry_interval, + ignore_failure, name, operation_mapping, inputs): + task = _test_model( + is_valid=is_valid, + storage=execution_storage, + model_name='task', + model_cls=Task, + model_kwargs=dict( + status=status, + execution_id=execution_storage.execution.list()[0].id, + due_at=due_at, + started_at=started_at, + ended_at=ended_at, + max_attempts=max_attempts, + retry_count=retry_count, + retry_interval=retry_interval, + ignore_failure=ignore_failure, + name=name, + operation_mapping=operation_mapping, + inputs=inputs, + )) + if is_valid: + assert task.execution == execution_storage.execution.list()[0] + + def test_task_max_attempts_validation(self): + def create_task(max_attempts): + Task(execution_id='eid', + name='name', + operation_mapping='', + inputs={}, + max_attempts=max_attempts) + create_task(max_attempts=1) + create_task(max_attempts=2) + create_task(max_attempts=Task.INFINITE_RETRIES) + with pytest.raises(ValueError): + create_task(max_attempts=0) + with pytest.raises(ValueError): + create_task(max_attempts=-2) + + +def test_inner_dict_update(empty_storage): + inner_dict = {'inner_value': 1} + pc = ProviderContext(name='name', context={ + 'inner_dict': {'inner_value': inner_dict}, + 'value': 0 + }) + empty_storage.provider_context.put(pc) + + storage_pc = empty_storage.provider_context.get(pc.id) + assert storage_pc == pc + + storage_pc.context['inner_dict']['inner_value'] = 2 + storage_pc.context['value'] = -1 + empty_storage.provider_context.update(storage_pc) + storage_pc = empty_storage.provider_context.get(pc.id) + + assert storage_pc.context['inner_dict']['inner_value'] == 2 + assert storage_pc.context['value'] == -1 diff --git a/tests/storage/test_models_api.py b/tests/storage/test_models_api.py deleted file mode 100644 index 2b928205..00000000 --- a/tests/storage/test_models_api.py +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from aria.storage import _ModelApi, models -from aria.storage.exceptions import StorageError - -from . import InMemoryModelDriver - - -def test_models_api_base(): - driver = InMemoryModelDriver() - driver.create('provider_context') - table = _ModelApi('provider_context', driver, models.ProviderContext) - assert repr(table) == ( - '{table.name}(driver={table.driver}, ' - 'model={table.model_cls})'.format(table=table)) - provider_context = models.ProviderContext(context={}, name='context_name', id='id') - - table.store(provider_context) - assert table.get('id') == provider_context - - assert [i for i in table.iter()] == [provider_context] - assert [i for i in table] == [provider_context] - - table.delete('id') - - with pytest.raises(StorageError): - table.get('id') - - -def test_iterable_model_api(): - driver = InMemoryModelDriver() - driver.create('deployment_update') - driver.create('deployment_update_step') - model_api = _ModelApi('deployment_update', driver, models.DeploymentUpdate) - deployment_update = models.DeploymentUpdate( - id='id', - deployment_id='deployment_id', - deployment_plan={}, - execution_id='execution_id', - steps=[models.DeploymentUpdateStep( - id='step_id', - action='add', - entity_type='node', - entity_id='node_id' - )] - ) - - model_api.store(deployment_update) - assert [i for i in model_api.iter()] == [deployment_update] - assert [i for i in model_api] == [deployment_update] - - model_api.delete('id') - - with pytest.raises(StorageError): - model_api.get('id') diff --git a/tests/storage/test_resource_storage.py b/tests/storage/test_resource_storage.py index 918b2704..9b5f7826 100644 --- a/tests/storage/test_resource_storage.py +++ b/tests/storage/test_resource_storage.py @@ -1,4 +1,4 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more +# Licensed to the Apache ftware Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 @@ -18,15 +18,17 @@ import pytest -from aria.storage.exceptions import StorageError -from aria.storage import ResourceStorage, FileSystemResourceDriver +from aria.storage.filesystem_rapi import FileSystemResourceAPI +from aria.storage import ( + exceptions, + ResourceStorage +) from . import TestFileSystem class TestResourceStorage(TestFileSystem): def _create(self, storage): storage.register('blueprint') - storage.setup() def _upload(self, storage, tmp_path, id): with open(tmp_path, 'w') as f: @@ -41,24 +43,26 @@ def _upload_dir(self, storage, tmp_dir, tmp_file_name, id): storage.blueprint.upload(entry_id=id, source=tmp_dir) + def _create_storage(self): + return ResourceStorage(FileSystemResourceAPI, + api_kwargs=dict(directory=self.path)) + def test_name(self): - driver = FileSystemResourceDriver(directory=self.path) - storage = ResourceStorage(driver, resources=['blueprint']) - assert repr(storage) == 'ResourceStorage(driver={driver})'.format( - driver=driver - ) - assert repr(storage.registered['blueprint']) == ( - 'ResourceApi(driver={driver}, resource={resource_name})'.format( - driver=driver, - resource_name='blueprint')) + api = FileSystemResourceAPI + storage = ResourceStorage(FileSystemResourceAPI, + items=['blueprint'], + api_kwargs=dict(directory=self.path)) + assert repr(storage) == 'ResourceStorage(api={api})'.format(api=api) + assert 'directory={resource_dir}'.format(resource_dir=self.path) in \ + repr(storage.registered['blueprint']) def test_create(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) assert os.path.exists(os.path.join(self.path, 'blueprint')) def test_upload_file(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = ResourceStorage(FileSystemResourceAPI, api_kwargs=dict(directory=self.path)) self._create(storage) tmpfile_path = tempfile.mkstemp(suffix=self.__class__.__name__, dir=self.path)[1] self._upload(storage, tmpfile_path, id='blueprint_id') @@ -74,7 +78,7 @@ def test_upload_file(self): assert f.read() == 'fake context' def test_download_file(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) tmpfile_path = tempfile.mkstemp(suffix=self.__class__.__name__, dir=self.path)[1] tmpfile_name = os.path.basename(tmpfile_path) @@ -90,27 +94,27 @@ def test_download_file(self): assert f.read() == 'fake context' def test_download_non_existing_file(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) - with pytest.raises(StorageError): + with pytest.raises(exceptions.StorageError): storage.blueprint.download(entry_id='blueprint_id', destination='', path='fake_path') def test_data_non_existing_file(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) - with pytest.raises(StorageError): - storage.blueprint.data(entry_id='blueprint_id', path='fake_path') + with pytest.raises(exceptions.StorageError): + storage.blueprint.read(entry_id='blueprint_id', path='fake_path') def test_data_file(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) tmpfile_path = tempfile.mkstemp(suffix=self.__class__.__name__, dir=self.path)[1] self._upload(storage, tmpfile_path, 'blueprint_id') - assert storage.blueprint.data(entry_id='blueprint_id') == 'fake context' + assert storage.blueprint.read(entry_id='blueprint_id') == 'fake context' def test_upload_dir(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) tmp_dir = tempfile.mkdtemp(suffix=self.__class__.__name__, dir=self.path) second_level_tmp_dir = tempfile.mkdtemp(dir=tmp_dir) @@ -127,7 +131,7 @@ def test_upload_dir(self): assert os.path.isfile(destination) def test_upload_path_in_dir(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) tmp_dir = tempfile.mkdtemp(suffix=self.__class__.__name__, dir=self.path) second_level_tmp_dir = tempfile.mkdtemp(dir=tmp_dir) @@ -151,7 +155,7 @@ def test_upload_path_in_dir(self): os.path.basename(second_update_file))) def test_download_dir(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) tmp_dir = tempfile.mkdtemp(suffix=self.__class__.__name__, dir=self.path) second_level_tmp_dir = tempfile.mkdtemp(dir=tmp_dir) @@ -174,7 +178,7 @@ def test_download_dir(self): assert f.read() == 'fake context' def test_data_dir(self): - storage = ResourceStorage(FileSystemResourceDriver(directory=self.path)) + storage = self._create_storage() self._create(storage) tmp_dir = tempfile.mkdtemp(suffix=self.__class__.__name__, dir=self.path) @@ -183,5 +187,5 @@ def test_data_dir(self): storage.blueprint.upload(entry_id='blueprint_id', source=tmp_dir) - with pytest.raises(StorageError): - storage.blueprint.data(entry_id='blueprint_id') + with pytest.raises(exceptions.StorageError): + storage.blueprint.read(entry_id='blueprint_id')