From c3906c62280e14b3b8e07605bd9e45e635e2669d Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 25 Nov 2019 23:10:02 +0100 Subject: [PATCH] Make `Group` sub classable through entry points We add the `aiida.groups` entry point group where sub classes of the `aiida.orm.groups.Group` class can be registered. A new metaclass is used to automatically set the `type_string` based on the entry point of the `Group` sub class. This will make it possible to reload the correct sub class when reloading from the database. If the `GroupMeta` metaclass cannot retrieve the corresponding entry point of the subclass, a warning is issued that any instances of this class will not be storable and the `_type_string` attribute is set to `None`. This can be checked by the `store` method which will make it fail. We choose to only raise in the `store` method such that it is still possible to define and instantiate subclasses of `Group` that have not yet been registered. This is useful for testing and experimenting. Since the group type strings are now based on the entry point names, the existing group type strings in the database have to be migrated: * `user` -> `core` * `data.upf` -> `core.upf` * `auto.import` -> `core.import` * `auto.run` -> `core.run` When loading a `Group` instance from the database, the loader will try to resolve the type string to the corresponding subclass through the entry points. If this fails, a warning is issued and we fall back on the base `Group` class. 
--- .../db/migrations/0044_dbgroup_type_string.py | 44 ++++ .../backends/djsite/db/migrations/__init__.py | 2 +- .../bf591f31dd12_dbgroup_type_string.py | 45 ++++ aiida/cmdline/commands/cmd_data/cmd_upf.py | 11 +- aiida/cmdline/commands/cmd_group.py | 22 +- aiida/cmdline/commands/cmd_run.py | 1 + aiida/cmdline/params/types/group.py | 4 +- aiida/orm/autogroup.py | 33 +-- aiida/orm/convert.py | 5 +- aiida/orm/groups.py | 108 ++++++--- aiida/orm/implementation/groups.py | 2 +- aiida/orm/nodes/data/upf.py | 30 +-- aiida/plugins/entry_point.py | 2 + aiida/plugins/factories.py | 23 +- aiida/tools/importexport/common/config.py | 4 +- .../dbimport/backends/django/__init__.py | 6 +- .../dbimport/backends/sqla/__init__.py | 6 +- setup.json | 6 + ...est_migrations_0044_dbgroup_type_string.py | 63 +++++ .../aiida_sqlalchemy/test_migrations.py | 66 ++++++ tests/cmdline/commands/test_group.py | 2 +- tests/cmdline/commands/test_run.py | 223 ++++++++++-------- tests/orm/data/test_upf.py | 16 +- tests/orm/test_groups.py | 53 +++++ tests/tools/graph/test_age.py | 6 +- .../tools/importexport/test_prov_redesign.py | 4 +- 26 files changed, 554 insertions(+), 233 deletions(-) create mode 100644 aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py create mode 100644 aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py create mode 100644 tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py diff --git a/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py b/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py new file mode 100644 index 00000000000..c515c8a1462 --- /dev/null +++ b/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,too-few-public-methods +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed.""" + +# pylint: disable=no-name-in-module,import-error +from django.db import migrations +from aiida.backends.djsite.db.migrations import upgrade_schema_version + +REVISION = '1.0.44' +DOWN_REVISION = '1.0.43' + +forward_sql = [ + """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", + """UPDATE db_dbgroup SET type_string = 'core.upf' WHERE type_string = 'data.upf';""", + """UPDATE db_dbgroup SET type_string = 'core.import' WHERE type_string = 'auto.import';""", + """UPDATE db_dbgroup SET type_string = 'core.run' WHERE type_string = 'auto.run';""", +] + +reverse_sql = [ + """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", + """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", + """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", + """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.run';""", +] + + +class Migration(migrations.Migration): + """Migration after the update of group `type_string`""" + dependencies = [ + ('db', '0043_default_link_label'), + ] + + operations = [ + migrations.RunSQL(sql='\n'.join(forward_sql), reverse_sql='\n'.join(reverse_sql)), + upgrade_schema_version(REVISION, DOWN_REVISION), + ] diff --git a/aiida/backends/djsite/db/migrations/__init__.py b/aiida/backends/djsite/db/migrations/__init__.py index a832b4e5f7d..41ee2b3d2ce 100644 --- a/aiida/backends/djsite/db/migrations/__init__.py +++ b/aiida/backends/djsite/db/migrations/__init__.py @@ -21,7 +21,7 @@ 
class DeserializationException(AiidaException): pass -LATEST_MIGRATION = '0043_default_link_label' +LATEST_MIGRATION = '0044_dbgroup_type_string' def _update_schema_version(version, apps, _): diff --git a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py new file mode 100644 index 00000000000..7c6a7063d10 --- /dev/null +++ b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed. + +Revision ID: bf591f31dd12 +Revises: 118349c10896 +Create Date: 2020-03-31 10:00:52.609146 + +""" +# pylint: disable=no-name-in-module,import-error,invalid-name,no-member +from alembic import op +from sqlalchemy.sql import text + +forward_sql = [ + """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", + """UPDATE db_dbgroup SET type_string = 'core.upf' WHERE type_string = 'data.upf';""", + """UPDATE db_dbgroup SET type_string = 'core.import' WHERE type_string = 'auto.import';""", + """UPDATE db_dbgroup SET type_string = 'core.run' WHERE type_string = 'auto.run';""", +] + +reverse_sql = [ + """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", + """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", + """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", + """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.run';""", +] + +# revision identifiers, used by Alembic. 
+revision = 'bf591f31dd12' +down_revision = '118349c10896' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = text('\n'.join(forward_sql)) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + conn = op.get_bind() + statement = text('\n'.join(reverse_sql)) + conn.execute(statement) diff --git a/aiida/cmdline/commands/cmd_data/cmd_upf.py b/aiida/cmdline/commands/cmd_data/cmd_upf.py index 78f79b0d9e8..745f4af7a2e 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_upf.py +++ b/aiida/cmdline/commands/cmd_data/cmd_upf.py @@ -64,22 +64,13 @@ def upf_listfamilies(elements, with_description): """ from aiida import orm from aiida.plugins import DataFactory - from aiida.orm.nodes.data.upf import UPFGROUP_TYPE UpfData = DataFactory('upf') # pylint: disable=invalid-name query = orm.QueryBuilder() query.append(UpfData, tag='upfdata') if elements is not None: query.add_filter(UpfData, {'attributes.element': {'in': elements}}) - query.append( - orm.Group, - with_node='upfdata', - tag='group', - project=['label', 'description'], - filters={'type_string': { - '==': UPFGROUP_TYPE - }} - ) + query.append(orm.UpfFamily, with_node='upfdata', tag='group', project=['label', 'description']) query.distinct() if query.count() > 0: diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py index d74e416bd58..ea52028c3f2 100644 --- a/aiida/cmdline/commands/cmd_group.py +++ b/aiida/cmdline/commands/cmd_group.py @@ -13,7 +13,7 @@ from aiida.common.exceptions import UniquenessError from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments, types +from aiida.cmdline.params import options, arguments from aiida.cmdline.utils import echo from aiida.cmdline.utils.decorators import with_dbenv @@ -178,18 +178,6 @@ def group_show(group, raw, limit, uuid): echo.echo(tabulate(table, headers=header)) -@with_dbenv() -def 
valid_group_type_strings(): - from aiida.orm import GroupTypeString - return tuple(i.value for i in GroupTypeString) - - -@with_dbenv() -def user_defined_group(): - from aiida.orm import GroupTypeString - return GroupTypeString.USER.value - - @verdi_group.command('list') @options.ALL_USERS(help='Show groups for all users, rather than only for the current user') @click.option( @@ -204,8 +192,7 @@ def user_defined_group(): '-t', '--type', 'group_type', - type=types.LazyChoice(valid_group_type_strings), - default=user_defined_group, + default='core', help='Show groups of a specific type, instead of user-defined groups. Start with semicolumn if you want to ' 'specify aiida-internal type' ) @@ -330,9 +317,8 @@ def group_list( def group_create(group_label): """Create an empty group with a given name.""" from aiida import orm - from aiida.orm import GroupTypeString - group, created = orm.Group.objects.get_or_create(label=group_label, type_string=GroupTypeString.USER.value) + group, created = orm.Group.objects.get_or_create(label=group_label) if created: echo.echo_success("Group created with PK = {} and name '{}'".format(group.id, group.label)) @@ -351,7 +337,7 @@ def group_copy(source_group, destination_group): Note that the destination group may not exist.""" from aiida import orm - dest_group, created = orm.Group.objects.get_or_create(label=destination_group, type_string=source_group.type_string) + dest_group, created = orm.Group.objects.get_or_create(label=destination_group) # Issue warning if destination group is not empty and get user confirmation to continue if not created and not dest_group.is_empty: diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index bd3972b8418..d46b6f984cc 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -150,5 +150,6 @@ def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, ex # Re-raise the exception to have the error code properly returned at 
the end raise finally: + autogroup.current_autogroup = None if handle: handle.close() diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index ef216044e71..6150f6d0625 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -40,12 +40,12 @@ def orm_class_loader(self): @with_dbenv() def convert(self, value, param, ctx): - from aiida.orm import Group, GroupTypeString + from aiida.orm import Group try: group = super().convert(value, param, ctx) except click.BadParameter: if self._create_if_not_exist: - group = Group(label=value, type_string=GroupTypeString.USER.value) + group = Group(label=value) else: raise diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index 16bf03f1c1c..06e83185e35 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -14,21 +14,18 @@ from aiida.common import exceptions, timezone from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import GroupTypeString, Group +from aiida.orm import AutoGroup from aiida.plugins.entry_point import get_entry_point_string_from_class CURRENT_AUTOGROUP = None -VERDIAUTOGROUP_TYPE = GroupTypeString.VERDIAUTOGROUP_TYPE.value - class Autogroup: - """ - An object used for the autogrouping of objects. - The autogrouping is checked by the Node.store() method. - In the store(), the Node will check if CURRENT_AUTOGROUP is != None. - If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. - Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. + """Class to create a new `AutoGroup` instance that will, while active, automatically contain all nodes being stored. 
+ + The autogrouping is checked by the `Node.store()` method which, if `CURRENT_AUTOGROUP is not None` the method + `Autogroup.is_to_be_grouped` is called to decide whether to put the current node being stored in the current + `AutoGroup` instance. The exclude/include lists are lists of strings like: ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, @@ -198,7 +195,7 @@ def clear_group_cache(self): self._group_label = None def get_or_create_group(self): - """Return the current Autogroup, or create one if None has been set yet. + """Return the current `AutoGroup`, or create one if None has been set yet. This function implements a somewhat complex logic that is however needed to make sure that, even if `verdi run` is called at the same time multiple @@ -219,16 +216,10 @@ def get_or_create_group(self): # So the group with the same name can be returned quickly in future # calls of this method. if self._group_label is not None: - results = [ - res[0] for res in QueryBuilder(). - append(Group, filters={ - 'label': self._group_label, - 'type_string': VERDIAUTOGROUP_TYPE - }, project='*').iterall() - ] + builder = QueryBuilder().append(AutoGroup, filters={'label': self._group_label}) + results = [res[0] for res in builder.iterall()] if results: - # If it is not empty, it should have only one result due to the - # uniqueness constraints + # If it is not empty, it should have only one result due to the uniqueness constraints assert len(results) == 1, 'I got more than one autogroup with the same label!' return results[0] # There are no results: probably the group has been deleted. 
@@ -239,7 +230,7 @@ def get_or_create_group(self): # Try to do a preliminary QB query to avoid to do too many try/except # if many of the prefix_NUMBER groups already exist queryb = QueryBuilder().append( - Group, + AutoGroup, filters={ 'or': [{ 'label': { @@ -274,7 +265,7 @@ def get_or_create_group(self): while True: try: label = label_prefix if counter == 0 else '{}_{}'.format(label_prefix, counter) - group = Group(label=label, type_string=VERDIAUTOGROUP_TYPE).store() + group = AutoGroup(label=label).store() self._group_label = group.label except exceptions.IntegrityError: counter += 1 diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py index 197253cffda..d6b577773b4 100644 --- a/aiida/orm/convert.py +++ b/aiida/orm/convert.py @@ -61,8 +61,9 @@ def _(backend_entity): @get_orm_entity.register(BackendGroup) def _(backend_entity): - from . import groups - return groups.Group.from_backend_entity(backend_entity) + from .groups import load_group_class + group_class = load_group_class(backend_entity.type_string) + return group_class.from_backend_entity(backend_entity) @get_orm_entity.register(BackendComputer) diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index cb7b4af801e..6075653e87a 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ AiiDA Group entites""" - +from abc import ABCMeta from enum import Enum import warnings @@ -21,19 +21,63 @@ from . import entities from . import users -__all__ = ('Group', 'GroupTypeString') +__all__ = ('Group', 'GroupTypeString', 'AutoGroup', 'ImportGroup', 'UpfFamily') + + +def load_group_class(type_string): + """Load the sub class of `Group` that corresponds to the given `type_string`. + + .. note:: will fall back on `aiida.orm.groups.Group` if `type_string` cannot be resolved to loadable entry point. 
+ + :param type_string: the entry point name of the `Group` sub class + :return: sub class of `Group` registered through an entry point + """ + from aiida.common.exceptions import EntryPointError + from aiida.plugins.entry_point import load_entry_point + + try: + group_class = load_entry_point('aiida.groups', type_string) + except EntryPointError: + message = 'could not load entry point `{}`, falling back onto `Group` base class.'.format(type_string) + warnings.warn(message) # pylint: disable=no-member + group_class = Group + + return group_class + + +class GroupMeta(ABCMeta): + """Meta class for `aiida.orm.groups.Group` to automatically set the `type_string` attribute.""" + + def __new__(mcs, name, bases, namespace, **kwargs): + from aiida.plugins.entry_point import get_entry_point_from_class + + newcls = ABCMeta.__new__(mcs, name, bases, namespace, **kwargs) # pylint: disable=too-many-function-args + + entry_point_group, entry_point = get_entry_point_from_class(namespace['__module__'], name) + + if entry_point_group is None or entry_point_group != 'aiida.groups': + newcls._type_string = None + message = 'no registered entry point for `{}` so its instances will not be storable.'.format(name) + warnings.warn(message) # pylint: disable=no-member + else: + newcls._type_string = entry_point.name # pylint: disable=protected-access + + return newcls class GroupTypeString(Enum): - """A simple enum of allowed group type strings.""" + """A simple enum of allowed group type strings. + .. deprecated:: 1.2.0 + This enum is deprecated and will be removed in `v2.0.0`. 
+ """ UPFGROUP_TYPE = 'data.upf' IMPORTGROUP_TYPE = 'auto.import' VERDIAUTOGROUP_TYPE = 'auto.run' USER = 'user' -class Group(entities.Entity): +class Group(entities.Entity, metaclass=GroupMeta): """An AiiDA ORM implementation of group of nodes.""" class Collection(entities.Collection): @@ -54,21 +98,10 @@ def get_or_create(self, label=None, **kwargs): if not label: raise ValueError('Group label must be provided') - filters = {'label': label} - - if 'type_string' in kwargs: - if not isinstance(kwargs['type_string'], str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(kwargs['type_string'])) - ) - - filters['type_string'] = kwargs['type_string'] - - res = self.find(filters=filters) + res = self.find(filters={'label': label}) if not res: - return Group(label, backend=self.backend, **kwargs).store(), True + return self.entity_type(label, backend=self.backend, **kwargs).store(), True if len(res) > 1: raise exceptions.MultipleObjectsError('More than one groups found in the database') @@ -83,12 +116,15 @@ def delete(self, id): # pylint: disable=invalid-name, redefined-builtin """ self._backend.groups.delete(id) - def __init__(self, label=None, user=None, description='', type_string=GroupTypeString.USER.value, backend=None): + def __init__(self, label=None, user=None, description='', type_string=None, backend=None): """ Create a new group. Either pass a dbgroup parameter, to reload a group from the DB (and then, no further parameters are allowed), or pass the parameters for the Group creation. + .. deprecated:: 1.2.0 + The parameter `type_string` will be removed in `v2.0.0` and is now determined automatically. 
+ :param label: The group label, required on creation :type label: str @@ -105,12 +141,11 @@ def __init__(self, label=None, user=None, description='', type_string=GroupTypeS if not label: raise ValueError('Group label must be provided') - # Check that chosen type_string is allowed - if not isinstance(type_string, str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(type_string)) - ) + if type_string is not None: + message = '`type_string` is deprecated because it is determined automatically, using default `core`' + warnings.warn(message) # pylint: disable=no-member + + type_string = self._type_string backend = backend or get_manager().get_backend() user = user or users.User.objects(backend).get_default() @@ -130,6 +165,13 @@ def __str__(self): return '"{}" [user-defined], of user {}'.format(self.label, self.user.email) + def store(self): + """Verify that the group is allowed to be stored, which is the case along as `type_string` is set.""" + if self._type_string is None: + raise exceptions.StoringNotAllowed('`type_string` is `None` so the group cannot be stored.') + + return super().store() + @property def label(self): """ @@ -295,11 +337,7 @@ def get(cls, **kwargs): filters = {} if 'type_string' in kwargs: - if not isinstance(kwargs['type_string'], str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(kwargs['type_string'])) - ) + type_check(kwargs['type_string'], str) query = QueryBuilder() for key, val in kwargs.items(): @@ -382,3 +420,15 @@ def get_schema(): 'type': 'unicode' } } + + +class AutoGroup(Group): + """Group to be used to contain all nodes from an export archive that has been imported.""" + + +class ImportGroup(Group): + """Group to be used to contain all nodes from an export archive that has been imported.""" + + +class UpfFamily(Group): + """Group that represents a pseudo potential family containing 
`UpfData` nodes.""" diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index 74349e25e60..f39314060f5 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -101,7 +101,7 @@ def get_or_create(cls, *args, **kwargs): :return: (group, created) where group is the group (new or existing, in any case already stored) and created is a boolean saying """ - res = cls.query(name=kwargs.get('name'), type_string=kwargs.get('type_string')) + res = cls.query(name=kwargs.get('name')) if not res: return cls.create(*args, **kwargs), True diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index d35e1b35eeb..33cf9b64210 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -8,20 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module of `Data` sub class to represent a pseudopotential single file in UPF format and related utilities.""" - import json import re from upf_to_json import upf_to_json - -from aiida.common.lang import classproperty -from aiida.orm import GroupTypeString from .singlefile import SinglefileData __all__ = ('UpfData',) -UPFGROUP_TYPE = GroupTypeString.UPFGROUP_TYPE.value - REGEX_UPF_VERSION = re.compile(r""" \s*.*)"> @@ -107,9 +101,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T nfiles = len(filenames) automatic_user = orm.User.objects.get_default() - group, group_created = orm.Group.objects.get_or_create( - label=group_label, type_string=UPFGROUP_TYPE, user=automatic_user - ) + group, group_created = orm.UpfFamily.objects.get_or_create(label=group_label, user=automatic_user) if group.user.email != automatic_user.email: raise UniquenessError( @@ -312,12 +304,6 @@ def get_or_create(cls, filepath, use_first=False, store_upf=True): return (pseudos[0], False) - @classproperty - def upffamily_type_string(cls): - 
"""Return the type string used for UPF family groups.""" - # pylint: disable=no-self-argument,no-self-use - return UPFGROUP_TYPE - def store(self, *args, **kwargs): """Store the node, reparsing the file so that the md5 and the element are correctly reset.""" # pylint: disable=arguments-differ @@ -388,11 +374,11 @@ def set_file(self, file, filename=None): def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" - from aiida.orm import Group + from aiida.orm import UpfFamily from aiida.orm import QueryBuilder query = QueryBuilder() - query.append(Group, filters={'type_string': {'==': self.upffamily_type_string}}, tag='group', project='label') + query.append(UpfFamily, tag='group', project='label') query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group') return [label for label, in query.all()] @@ -465,9 +451,9 @@ def get_upf_group(cls, group_label): :param group_label: the family group label :return: the `Group` with the given label, if it exists """ - from aiida.orm import Group + from aiida.orm import UpfFamily - return Group.get(label=group_label, type_string=cls.upffamily_type_string) + return UpfFamily.get(label=group_label) @classmethod def get_upf_groups(cls, filter_elements=None, user=None): @@ -480,12 +466,12 @@ def get_upf_groups(cls, filter_elements=None, user=None): If defined, it should be either a `User` instance or the user email. :return: list of `Group` entities of type UPF. 
""" - from aiida.orm import Group + from aiida.orm import UpfFamily from aiida.orm import QueryBuilder from aiida.orm import User builder = QueryBuilder() - builder.append(Group, filters={'type_string': {'==': cls.upffamily_type_string}}, tag='group', project='*') + builder.append(UpfFamily, tag='group', project='*') if user: builder.append(User, filters={'email': {'==': user}}, with_group='group') @@ -496,7 +482,7 @@ def get_upf_groups(cls, filter_elements=None, user=None): if filter_elements is not None: builder.append(UpfData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - builder.order_by({Group: {'id': 'asc'}}) + builder.order_by({UpfFamily: {'id': 'asc'}}) return [group for group, in builder.all()] diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 2abe6be0771..46e4bf3c7e2 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -54,6 +54,7 @@ class EntryPointFormat(enum.Enum): 'aiida.calculations': 'aiida.orm.nodes.process.calculation.calcjob', 'aiida.cmdline.data': 'aiida.cmdline.data', 'aiida.data': 'aiida.orm.nodes.data', + 'aiida.groups': 'aiida.orm.groups', 'aiida.node': 'aiida.orm.nodes', 'aiida.parsers': 'aiida.parsers.plugins', 'aiida.schedulers': 'aiida.schedulers.plugins', @@ -78,6 +79,7 @@ def validate_registered_entry_points(): # pylint: disable=invalid-name factory_mapping = { 'aiida.calculations': factories.CalculationFactory, 'aiida.data': factories.DataFactory, + 'aiida.groups': factories.GroupFactory, 'aiida.parsers': factories.ParserFactory, 'aiida.schedulers': factories.SchedulerFactory, 'aiida.transports': factories.TransportFactory, diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 6e5a9296e9d..1675ac6cb6a 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -14,8 +14,8 @@ from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( - 'BaseFactory', 'CalculationFactory', 'DataFactory', 
'DbImporterFactory', 'OrbitalFactory', 'ParserFactory', - 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' + 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory', + 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' ) @@ -107,6 +107,25 @@ def DbImporterFactory(entry_point_name): raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) +def GroupFactory(entry_point_name): + """Return the `Group` sub class registered under the given entry point. + + :param entry_point_name: the entry point name + :return: sub class of :py:class:`~aiida.orm.groups.Group` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. + """ + from aiida.orm import Group + + entry_point_group = 'aiida.groups' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Group,) + + if isclass(entry_point) and issubclass(entry_point, Group): + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + def OrbitalFactory(entry_point_name): """Return the `Orbital` sub class registered under the given entry point. 
diff --git a/aiida/tools/importexport/common/config.py b/aiida/tools/importexport/common/config.py index 0baac376c9e..549c22be7d0 100644 --- a/aiida/tools/importexport/common/config.py +++ b/aiida/tools/importexport/common/config.py @@ -9,15 +9,13 @@ ########################################################################### # pylint: disable=invalid-name """ Configuration file for AiiDA Import/Export module """ - -from aiida.orm import Computer, Group, GroupTypeString, Node, User, Log, Comment +from aiida.orm import Computer, Group, Node, User, Log, Comment __all__ = ('EXPORT_VERSION',) # Current export version EXPORT_VERSION = '0.8' -IMPORTGROUP_TYPE = GroupTypeString.IMPORTGROUP_TYPE.value DUPL_SUFFIX = ' (Imported #{})' # The name of the subfolder in which the node files are stored diff --git a/aiida/tools/importexport/dbimport/backends/django/__init__.py b/aiida/tools/importexport/dbimport/backends/django/__init__.py index d97ad70d1db..aa463f5ffbe 100644 --- a/aiida/tools/importexport/dbimport/backends/django/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/django/__init__.py @@ -21,10 +21,10 @@ from aiida.common.links import LinkType, validate_link_label from aiida.common.utils import grouper, get_object_from_string from aiida.orm.utils.repository import Repository -from aiida.orm import QueryBuilder, Node, Group +from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, IMPORTGROUP_TYPE, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -673,7 +673,7 @@ def import_data_dj( 
"Overflow of import groups (more than 100 import groups exists with basename '{}')" ''.format(basename) ) - group = Group(label=group_label, type_string=IMPORTGROUP_TYPE).store() + group = ImportGroup(label=group_label).store() # Add all the nodes to the new group # TODO: decide if we want to return the group label diff --git a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py index f08de125ec6..2e800b13610 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py @@ -20,13 +20,13 @@ from aiida.common.folders import SandboxFolder, RepositoryFolder from aiida.common.links import LinkType from aiida.common.utils import get_object_from_string -from aiida.orm import QueryBuilder, Node, Group, WorkflowNode, CalculationNode, Data +from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.orm.utils.links import link_triple_exists, validate_link from aiida.orm.utils.repository import Repository from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, IMPORTGROUP_TYPE, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -664,7 +664,7 @@ def import_data_sqla( "Overflow of import groups (more than 100 import groups exists with basename '{}')" ''.format(basename) ) - group = Group(label=group_label, type_string=IMPORTGROUP_TYPE) + group = ImportGroup(label=group_label) session.add(group.backend_entity._dbmodel) # Adding nodes to group avoiding the SQLA ORM to increase speed diff --git a/setup.json 
b/setup.json index de5f83b295a..ba42fd31369 100644 --- a/setup.json +++ b/setup.json @@ -158,6 +158,12 @@ "structure = aiida.orm.nodes.data.structure:StructureData", "upf = aiida.orm.nodes.data.upf:UpfData" ], + "aiida.groups": [ + "core = aiida.orm.groups:Group", + "core.auto = aiida.orm.groups:AutoGroup", + "core.import = aiida.orm.groups:ImportGroup", + "core.upf = aiida.orm.groups:UpfFamily" + ], "aiida.node": [ "data = aiida.orm.nodes.data.data:Data", "process = aiida.orm.nodes.process.process:ProcessNode", diff --git a/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py b/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py new file mode 100644 index 00000000000..b2eebe70a74 --- /dev/null +++ b/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
#
+#                                                                         #
+# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
+# For further information on the license, see the LICENSE.txt file        #
+# For further information please visit http://www.aiida.net               #
+###########################################################################
+# pylint: disable=import-error,no-name-in-module,invalid-name
+"""Test migration of `type_string` after the `Group` class became pluggable."""
+
+from .test_migrations_common import TestMigrations
+
+
+class TestGroupTypeStringMigration(TestMigrations):
+    """Test migration of `type_string` after the `Group` class became pluggable."""
+
+    migrate_from = '0043_default_link_label'
+    migrate_to = '0044_dbgroup_type_string'
+
+    def setUpBeforeMigration(self):
+        DbGroup = self.apps.get_model('db', 'DbGroup')
+
+        # test user group type_string: 'user' -> 'core'
+        group_user = DbGroup(label='01', user_id=self.default_user.id, type_string='user')
+        group_user.save()
+        self.group_user_pk = group_user.pk
+
+        # test data.upf group type_string: 'data.upf' -> 'core.upf'
+        group_data_upf = DbGroup(label='02', user_id=self.default_user.id, type_string='data.upf')
+        group_data_upf.save()
+        self.group_data_upf_pk = group_data_upf.pk
+
+        # test auto.import group type_string: 'auto.import' -> 'core.import'
+        group_autoimport = DbGroup(label='03', user_id=self.default_user.id, type_string='auto.import')
+        group_autoimport.save()
+        self.group_autoimport_pk = group_autoimport.pk
+
+        # test auto.run group type_string: 'auto.run' -> 'core.run'
+        group_autorun = DbGroup(label='04', user_id=self.default_user.id, type_string='auto.run')
+        group_autorun.save()
+        self.group_autorun_pk = group_autorun.pk
+
+    def test_group_string_update(self):
+        """Test that the type strings were updated correctly."""
+        DbGroup = self.apps.get_model('db', 'DbGroup')
+
+        # 'user' -> 'core'
+        group_user = DbGroup.objects.get(pk=self.group_user_pk)
+        self.assertEqual(group_user.type_string, 'core')
+
+        # 'data.upf' ->
'core.upf' + group_data_upf = DbGroup.objects.get(pk=self.group_data_upf_pk) + self.assertEqual(group_data_upf.type_string, 'core.upf') + + # 'auto.import' -> 'core.import' + group_autoimport = DbGroup.objects.get(pk=self.group_autoimport_pk) + self.assertEqual(group_autoimport.type_string, 'core.import') + + # 'auto.run' -> 'core.run' + group_autorun = DbGroup.objects.get(pk=self.group_autorun_pk) + self.assertEqual(group_autorun.type_string, 'core.run') diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 8e2046f2932..8a894f1b561 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -1642,3 +1642,69 @@ def test_data_migrated(self): finally: session.close() + + +class TestGroupTypeStringMigration(TestMigrationsSQLA): + """Test the migration that renames the DbGroup type strings.""" + + migrate_from = '118349c10896' # 118349c10896_default_link_label.py + migrate_to = 'bf591f31dd12' # bf591f31dd12_dbgroup_type_string.py + + def setUpBeforeMigration(self): + """Create the DbGroups with the old type strings.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + DbUser = self.get_current_table('db_dbuser') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + default_user = DbUser(email='{}@aiida.net'.format(self.id())) + session.add(default_user) + session.commit() + + # test user group type_string: 'user' -> 'core' + group_user = DbGroup(label='01', user_id=default_user.id, type_string='user') + session.add(group_user) + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = DbGroup(label='02', user_id=default_user.id, type_string='data.upf') + session.add(group_data_upf) + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = DbGroup(label='03', user_id=default_user.id, type_string='auto.import') + 
session.add(group_autoimport) + # test auto.run group type_string: 'auto.run' -> 'core.run' + group_autorun = DbGroup(label='04', user_id=default_user.id, type_string='auto.run') + session.add(group_autorun) + + session.commit() + + # Store values for later tests + self.group_user_pk = group_user.id + self.group_data_upf_pk = group_data_upf.id + self.group_autoimport_pk = group_autoimport.id + self.group_autorun_pk = group_autorun.id + + finally: + session.close() + + def test_group_string_update(self): + """Test that the type strings are properly migrated.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + # test user group type_string: 'user' -> 'core' + group_user = session.query(DbGroup).filter(DbGroup.id == self.group_user_pk).one() + self.assertEqual(group_user.type_string, 'core') + + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = session.query(DbGroup).filter(DbGroup.id == self.group_data_upf_pk).one() + self.assertEqual(group_data_upf.type_string, 'core.upf') + + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = session.query(DbGroup).filter(DbGroup.id == self.group_autoimport_pk).one() + self.assertEqual(group_autoimport.type_string, 'core.import') + + # test auto.run group type_string: 'auto.run' -> 'core.run' + group_autorun = session.query(DbGroup).filter(DbGroup.id == self.group_autorun_pk).one() + self.assertEqual(group_autorun.type_string, 'core.run') + finally: + session.close() diff --git a/tests/cmdline/commands/test_group.py b/tests/cmdline/commands/test_group.py index 43024206336..ab79f650b11 100644 --- a/tests/cmdline/commands/test_group.py +++ b/tests/cmdline/commands/test_group.py @@ -165,7 +165,7 @@ def test_show(self): self.assertClickResultNoException(result) for grpline in [ - 'Group label', 'dummygroup1', 'Group type_string', 'user', 'Group description', '' + 'Group label', 
'dummygroup1', 'Group type_string', 'core', 'Group description', '' ]: self.assertIn(grpline, result.output) diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 78c858420f7..4ed690bb20a 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -9,6 +9,7 @@ ########################################################################### """Tests for `verdi run`.""" import tempfile +import textwrap import warnings from click.testing import CliRunner @@ -31,21 +32,22 @@ def test_run_workfunction(self): that are defined within the script will fail, as the inspect module will not correctly be able to determin the full path of the source file. """ - from aiida.orm import load_node - from aiida.orm import WorkFunctionNode + from aiida.orm import load_node, WorkFunctionNode - script_content = """ -#!/usr/bin/env python -from aiida.engine import workfunction + script_content = textwrap.dedent( + """\ + #!/usr/bin/env python + from aiida.engine import workfunction -@workfunction -def wf(): - pass + @workfunction + def wf(): + pass -if __name__ == '__main__': - result, node = wf.run_get_node() - print(node.pk) - """ + if __name__ == '__main__': + result, node = wf.run_get_node() + print(node.pk) + """ + ) # If `verdi run` is not setup correctly, the script above when run with `verdi run` will fail, because when # the engine will try to create the node for the workfunction and create a copy of its sourcefile, namely the @@ -77,9 +79,8 @@ def setUp(self): super().setUp() self.cli_runner = CliRunner() - # I need to disable the global variable of this test environment, - # because invoke is just calling the function and therefore inheriting - # the global variable + # I need to disable the global variable of this test environment, because invoke is just calling the function + # and therefore inheriting the global variable self._old_autogroup = autogroup.CURRENT_AUTOGROUP autogroup.CURRENT_AUTOGROUP = None @@ 
-92,12 +93,15 @@ def tearDown(self): def test_autogroup(self): """Check if the autogroup is properly generated.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -111,7 +115,7 @@ def test_autogroup(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -119,12 +123,16 @@ def test_autogroup(self): def test_autogroup_custom_label(self): """Check if the autogroup is properly generated with the label specified.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" autogroup_label = 'SOME_group_LABEL' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -138,7 +146,7 @@ def test_autogroup_custom_label(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() 
self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -147,12 +155,15 @@ def test_autogroup_custom_label(self): def test_no_autogroup(self): """Check if the autogroup is not generated if ``verdi run`` is asked not to.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -166,61 +177,64 @@ def test_no_autogroup(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') def test_autogroup_filter_class(self): # pylint: disable=too-many-locals """Check if the autogroup is properly generated but filtered classes are skipped.""" - from aiida.orm import QueryBuilder, Node, Group, load_node - - script_content = """import sys -from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode -from aiida.plugins import CalculationFactory -from aiida.engine import run_get_node -ArithmeticAdd = CalculationFactory('arithmetic.add') - -computer = Computer( - name='localhost-example-{}'.format(sys.argv[1]), - hostname='localhost', - description='my computer', - transport_type='local', - scheduler_type='direct', - workdir='/tmp' -).store() -computer.configure() - -code = Code( - input_plugin_name='arithmetic.add', - remote_computer_exec=[computer, '/bin/true']).store() -inputs = { - 'x': 
Int(1), - 'y': Int(2), - 'code': code, - 'metadata': { - 'options': { - 'resources': { - 'num_machines': 1, - 'num_mpiprocs_per_machine': 1 + from aiida.orm import Code, QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + import sys + from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode + from aiida.plugins import CalculationFactory + from aiida.engine import run_get_node + ArithmeticAdd = CalculationFactory('arithmetic.add') + + computer = Computer( + name='localhost-example-{}'.format(sys.argv[1]), + hostname='localhost', + description='my computer', + transport_type='local', + scheduler_type='direct', + workdir='/tmp' + ).store() + computer.configure() + + code = Code( + input_plugin_name='arithmetic.add', + remote_computer_exec=[computer, '/bin/true']).store() + inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': code, + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1 + } + } + } } - } - } -} - -node1 = KpointsData().store() -node2 = ArrayData().store() -node3 = Int(3).store() -node4 = CalculationNode().store() -node5 = WorkflowNode().store() -_, node6 = run_get_node(ArithmeticAdd, **inputs) -print(node1.pk) -print(node2.pk) -print(node3.pk) -print(node4.pk) -print(node5.pk) -print(node6.pk) -""" - from aiida.orm import Code + + node1 = KpointsData().store() + node2 = ArrayData().store() + node3 = Int(3).store() + node4 = CalculationNode().store() + node5 = WorkflowNode().store() + _, node6 = run_get_node(ArithmeticAdd, **inputs) + print(node1.pk) + print(node2.pk) + print(node3.pk) + print(node4.pk) + print(node5.pk) + print(node6.pk) + """ + ) + Code() for idx, ( flags, @@ -283,27 +297,27 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals _ = load_node(pk6) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') - queryb.append(Group, with_node='node', 
filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_kptdata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_arraydata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk3}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_int = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk4}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_calc = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk5}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_wf = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk6}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_calcarithmetic = queryb.all() self.assertEqual( @@ -339,12 +353,16 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals def test_autogroup_clashing_label(self): """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) - script_content = """from aiida.orm import Data -node 
= Data().store() -print(node.pk) -""" autogroup_label = 'SOME_repeated_group_LABEL' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -358,7 +376,7 @@ def test_autogroup_clashing_label(self): pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -374,7 +392,7 @@ def test_autogroup_clashing_label(self): pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -383,12 +401,15 @@ def test_autogroup_clashing_label(self): def test_legacy_autogroup_name(self): """Check if the autogroup is properly generated when using the legacy --group-name flag.""" - from aiida.orm import QueryBuilder, Node, Group, load_node - - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) group_label = 'legacy-group-name' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: @@ -409,7 +430,7 @@ def test_legacy_autogroup_name(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - 
queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' diff --git a/tests/orm/data/test_upf.py b/tests/orm/data/test_upf.py index 228f8d9b77c..02922bc60f1 100644 --- a/tests/orm/data/test_upf.py +++ b/tests/orm/data/test_upf.py @@ -10,7 +10,6 @@ """ This module contains tests for UpfData and UpfData related functions. """ - import errno import tempfile import shutil @@ -95,8 +94,8 @@ def setUp(self): def tearDown(self): """Delete all groups and destroy the temporary directory created.""" - for group in orm.Group.objects.find(filters={'type_string': orm.GroupTypeString.UPFGROUP_TYPE.value}): - orm.Group.objects.delete(group.pk) + for group in orm.UpfFamily.objects.find(): + orm.UpfFamily.objects.delete(group.pk) try: shutil.rmtree(self.temp_dir) @@ -122,32 +121,31 @@ def test_get_upf_family_names(self): """Test the `UpfData.get_upf_family_names` method.""" label = 'family' - family, _ = orm.Group.objects.get_or_create(label=label, type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + family, _ = orm.UpfFamily.objects.get_or_create(label=label) family.add_nodes([self.pseudo_barium]) family.store() - self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label}) + self.assertEqual({group.label for group in orm.UpfFamily.objects.all()}, {label}) self.assertEqual(self.pseudo_barium.get_upf_family_names(), [label]) def test_get_upf_groups(self): """Test the `UpfData.get_upf_groups` class method.""" - type_string = orm.GroupTypeString.UPFGROUP_TYPE.value label_01 = 'family_01' label_02 = 'family_02' user = orm.User(email='alternate@localhost').store() - self.assertEqual(orm.UpfData.get_upf_groups(), []) + self.assertEqual(orm.UpfFamily.objects.all(), []) # Create group with default user and add `Ba` pseudo - 
family_01, _ = orm.Group.objects.get_or_create(label=label_01, type_string=type_string)
+        family_01, _ = orm.UpfFamily.objects.get_or_create(label=label_01)
         family_01.add_nodes([self.pseudo_barium])
         family_01.store()

         self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label_01})

         # Create group with different user and add `O` pseudo
-        family_02, _ = orm.Group.objects.get_or_create(label=label_02, type_string=type_string, user=user)
+        family_02, _ = orm.UpfFamily.objects.get_or_create(label=label_02, user=user)
         family_02.add_nodes([self.pseudo_oxygen])
         family_02.store()
diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py
index ce2797daad2..67b189195cf 100644
--- a/tests/orm/test_groups.py
+++ b/tests/orm/test_groups.py
@@ -8,6 +8,8 @@
 # For further information please visit http://www.aiida.net               #
 ###########################################################################
 """Test for the Group ORM class."""
+import pytest
+
 from aiida import orm
 from aiida.backends.testbase import AiidaTestCase
 from aiida.common import exceptions
@@ -272,3 +274,54 @@ def test_group_uuid_hashing_for_querybuidler(self):
         # And that the results are correct
         self.assertEqual(builder.count(), 1)
         self.assertEqual(builder.first()[0], group.id)
+
+
+class TestGroupsSubclasses(AiidaTestCase):
+    """Test rules around creating `Group` subclasses."""
+
+    @staticmethod
+    def test_creation_registered():
+        """Test rules around creating registered `Group` subclasses."""
+        group = orm.AutoGroup('some-label')
+        assert isinstance(group, orm.AutoGroup)
+        assert group.type_string == 'core.auto'
+
+        group, _ = orm.AutoGroup.objects.get_or_create('some-auto-group')
+        assert isinstance(group, orm.AutoGroup)
+        assert group.type_string == 'core.auto'
+
+    @staticmethod
+    def test_creation_unregistered():
+        """Test rules around creating `Group` subclasses without a registered entry point."""
+
+        # Defining an unregistered subclass should issue a warning and its type string should
be set to `None` + with pytest.warns(UserWarning): + + class SubGroup(orm.Group): + pass + + assert SubGroup._type_string is None # pylint: disable=protected-access + + # Creating an instance is allowed + instance = SubGroup(label='subgroup') + assert instance._type_string is None # pylint: disable=protected-access + + # Storing the instance, however, is forbidden and should raise + with pytest.raises(exceptions.StoringNotAllowed): + instance.store() + + @staticmethod + def test_loading_unregistered(): + """Test rules around loading `Group` subclasses without a registered entry point. + + Storing instances of unregistered subclasses is not allowed so we have to create one sneakily by instantiating + a normal group and manipulating the type string directly on the database model. + """ + group = orm.Group(label='group') + group.backend_entity.dbmodel.type_string = 'unregistered.subclass' + group.store() + + with pytest.warns(UserWarning): + loaded = orm.load_group(group.pk) + + assert isinstance(loaded, orm.Group) diff --git a/tests/tools/graph/test_age.py b/tests/tools/graph/test_age.py index dddf2323c20..538087c7d71 100644 --- a/tests/tools/graph/test_age.py +++ b/tests/tools/graph/test_age.py @@ -494,7 +494,7 @@ def test_groups(self): # Rule that only gets nodes connected by the same group queryb = orm.QueryBuilder() queryb.append(orm.Node, tag='nodes_in_set') - queryb.append(orm.Group, with_node='nodes_in_set', tag='groups_considered', filters={'type_string': 'user'}) + queryb.append(orm.Group, with_node='nodes_in_set', tag='groups_considered') queryb.append(orm.Data, with_group='groups_considered') initial_node = [node2.id] @@ -513,7 +513,7 @@ def test_groups(self): # But two rules chained should get both nodes and groups... 
queryb = orm.QueryBuilder() queryb.append(orm.Node, tag='nodes_in_set') - queryb.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'}) + queryb.append(orm.Group, with_node='nodes_in_set') rule1 = UpdateRule(queryb) queryb = orm.QueryBuilder() @@ -569,7 +569,7 @@ def test_groups(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, tag='nodes_in_set') - qb1.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'}) + qb1.append(orm.Group, with_node='nodes_in_set') rule1 = UpdateRule(qb1, track_edges=True) qb2 = orm.QueryBuilder() diff --git a/tests/tools/importexport/test_prov_redesign.py b/tests/tools/importexport/test_prov_redesign.py index 37f9a485a0d..5ef849c51c8 100644 --- a/tests/tools/importexport/test_prov_redesign.py +++ b/tests/tools/importexport/test_prov_redesign.py @@ -229,7 +229,7 @@ def test_group_name_and_type_change(self, temp_dir): groups_type_string = [g.type_string for g in [group_user, group_upf]] # Assert correct type strings exists prior to export - self.assertListEqual(groups_type_string, ['user', 'data.upf']) + self.assertListEqual(groups_type_string, ['core', 'core.upf']) # Export node filename = os.path.join(temp_dir, 'export.tar.gz') @@ -268,4 +268,4 @@ def test_group_name_and_type_change(self, temp_dir): # Check type_string content of "import group" import_group = orm.load_group(imported_groups_uuid[0]) - self.assertEqual(import_group.type_string, 'auto.import') + self.assertEqual(import_group.type_string, 'core.import')