Make `verdi import -G` (group) option work as intended (#2873)
Set `verdi import`'s group option to the general `options.GROUP()`, while still allowing a new Group to be created if needed. Only a Group object/entity can now be passed. The 'user_group' parameter has been renamed to 'group' in the import functions, and the handling of 'group' in `importexport.py` has been streamlined.
CasperWA authored and sphuber committed Jun 18, 2019
1 parent 5eaab43 commit cfcf213
Showing 4 changed files with 160 additions and 38 deletions.
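In short, the changed contract, as a minimal sketch (the group label and archive filename are hypothetical; the import path follows `aiida/orm/importexport.py` as changed below):

    from aiida.orm import Group
    from aiida.orm.importexport import import_data

    # A Group entity may be passed; if it is unstored, the import function stores it.
    group = Group(label='my_imports')
    import_data('archive.aiida', group=group, silent=True)

    # Passing a plain label is no longer accepted and raises
    # TypeError("group must be a Group entity").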
74 changes: 74 additions & 0 deletions aiida/backends/tests/cmdline/commands/test_import.py
@@ -19,6 +19,7 @@

from aiida.backends.testbase import AiidaTestCase
from aiida.cmdline.commands import cmd_import
from aiida.orm import Group


def get_archive_file(archive):
@@ -78,6 +79,79 @@ def test_import_archive(self):
self.assertIsNone(result.exception, result.output)
self.assertEqual(result.exit_code, 0, result.output)

@unittest.skip("Reenable when issue #2426 has been solved (migrate exported files from 0.3 to 0.4)")
def test_import_to_group(self):
"""
Test import to existing Group and that Nodes are added correctly for multiple imports
of the same, as well as separate, archives.
"""
archives = [
get_archive_file('calcjob/arithmetic.add.aiida'),
get_archive_file('export/migrate/export_v0.4.aiida')
]

group_label = "import_madness"
group = Group(group_label).store()

self.assertTrue(group.is_empty, msg="The Group should be empty.")

# Invoke `verdi import`, making sure there are no exceptions
options = ['-G', group.label] + [archives[0]]
result = self.cli_runner.invoke(cmd_import.cmd_import, options)
self.assertIsNone(result.exception, msg=result.output)
self.assertEqual(result.exit_code, 0, msg=result.output)

self.assertFalse(group.is_empty, msg="The Group should no longer be empty.")

nodes_in_group = group.count()

# Invoke `verdi import` again, making sure Group count doesn't change
options = ['-G', group.label] + [archives[0]]
result = self.cli_runner.invoke(cmd_import.cmd_import, options)
self.assertIsNone(result.exception, msg=result.output)
self.assertEqual(result.exit_code, 0, msg=result.output)

self.assertEqual(
group.count(),
nodes_in_group,
msg="The Group count should not have changed from {}. Instead it is now {}".format(
nodes_in_group, group.count()))

# Invoke `verdi import` again with new archive, making sure Group count is upped
options = ['-G', group.label] + [archives[1]]
result = self.cli_runner.invoke(cmd_import.cmd_import, options)
self.assertIsNone(result.exception, msg=result.output)
self.assertEqual(result.exit_code, 0, msg=result.output)

self.assertGreater(
group.count(),
nodes_in_group,
msg="There should now be more than {} nodes in group {} , instead there are {}".format(
nodes_in_group, group_label, group.count()))

@unittest.skip("Reenable when issue #2426 has been solved (migrate exported files from 0.3 to 0.4)")
def test_import_make_new_group(self):
"""Make sure imported entities are saved in new Group"""
# Initialization
group_label = "new_group_for_verdi_import"
archives = [get_archive_file('export/migrate/export_v0.4_simple.aiida')]

# Check Group does not already exist
group_search = Group.objects.find(filters={'label': group_label})
self.assertEqual(
len(group_search), 0, msg="A Group with label '{}' already exists, but it should not.".format(group_label))

# Invoke `verdi import`, making sure there are no exceptions
options = ['-G', group_label] + archives
result = self.cli_runner.invoke(cmd_import.cmd_import, options)
self.assertIsNone(result.exception, msg=result.output)
self.assertEqual(result.exit_code, 0, msg=result.output)

# Make sure new Group was created
(group, new_group) = Group.objects.get_or_create(group_label)
self.assertFalse(new_group, msg="The Group should not have been created now; it should have been created during the import.")
self.assertFalse(group.is_empty, msg="The Group should not be empty.")

@unittest.skip("Reenable when issue #2426 has been solved (migrate exported files from 0.3 to 0.4)")
def test_comment_mode(self):
"""
41 changes: 41 additions & 0 deletions aiida/backends/tests/test_export_and_import.py
@@ -665,6 +665,47 @@ def test_group_import_existing(self, temp_dir):
builder.append(orm.Group, filters={'label': {'like': grouplabel + '%'}})
self.assertEqual(builder.count(), 2)

@with_temp_dir
def test_import_to_group(self, temp_dir):
"""Test `group` parameter
Make sure an unstored Group is stored by the import function, forwarding the Group object.
Make sure the Group is correctly handled and used for imported nodes.
"""
from aiida.orm import load_group

# Create Nodes to export
data1 = orm.Data().store()
data2 = orm.Data().store()
node_uuids = [node.uuid for node in [data1, data2]]

# Export Nodes
filename = os.path.join(temp_dir, "export.aiida")
export([data1, data2], outfile=filename, silent=True)
self.reset_database()

# Create Group, do not store
group_label = "import_madness"
group = orm.Group(label=group_label)
group_uuid = group.uuid

# Try to import to this Group, providing only label - this should fail
with self.assertRaises(TypeError, msg="Labels should no longer be passable to `import_data`") as exc:
import_data(filename, group=group_label, silent=True)
exc = exc.exception
self.assertEqual(str(exc), "group must be a Group entity",
msg="The error message should be the same for both backends.")

# Import properly now, providing the Group object
import_data(filename, group=group, silent=True)

# Check Group for content
builder = orm.QueryBuilder().append(orm.Group, filters={'label': group_label}, project='uuid')
self.assertEqual(builder.count(), 1, msg="There should be exactly one Group with label {}. "
"Instead {} was found.".format(group_label, builder.count()))
imported_group = load_group(builder.all()[0][0])
self.assertEqual(imported_group.uuid, group_uuid)
for node in imported_group.nodes:
self.assertIn(node.uuid, node_uuids)

class TestCalculations(AiidaTestCase):
"""Test ex-/import cases related to Calculations"""
4 changes: 1 addition & 3 deletions aiida/cmdline/commands/cmd_import.py
@@ -143,9 +143,7 @@ def _migrate_archive(ctx, temp_folder, file_to_import, archive, non_interactive,
cls=options.MultipleValueOption,
help="Discover all URL targets pointing to files with the .aiida extension for these HTTP addresses. "
"Automatically discovered archive URLs will be downloadeded and added to ARCHIVES for importing")
@click.option(
'-G',
'--group',
@options.GROUP(
type=GroupParamType(create_if_not_exist=True),
help='Specify group to which all the import nodes will be added. If such a group does not exist, it will be'
' created automatically.')
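Since the option is now declared through the shared `options.GROUP()` with `GroupParamType(create_if_not_exist=True)`, the `-G` value is resolved to a Group entity before the command body runs, and the group is created on the fly when missing. A hypothetical invocation mirroring the tests above (the label and archive name are placeholders, and a loaded AiiDA profile is assumed):

    from click.testing import CliRunner
    from aiida.cmdline.commands import cmd_import

    # '-G my_label' resolves to the existing Group 'my_label' or creates it
    result = CliRunner().invoke(cmd_import.cmd_import, ['-G', 'my_label', 'archive.aiida'])
    assert result.exit_code == 0, result.output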
79 changes: 44 additions & 35 deletions aiida/orm/importexport.py
@@ -541,6 +541,7 @@ def import_data(in_path, group=None, silent=False, **kwargs):
detect the compression format (zip, tar.gz, tar.bz2, ...) and calls the
correct function.
:param in_path: the path to a file or folder that can be imported in AiiDA
:param group: Group wherein all imported Nodes will be placed.
:param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
The first letter acts on extras that are present in the original node and not present in the imported node.
Can be either:
@@ -564,14 +565,14 @@ def import_data(in_path, group=None, silent=False, **kwargs):
from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA

if configuration.PROFILE.database_backend == BACKEND_SQLA:
return import_data_sqla(in_path, user_group=group, silent=silent, **kwargs)
return import_data_sqla(in_path, group=group, silent=silent, **kwargs)
elif configuration.PROFILE.database_backend == BACKEND_DJANGO:
return import_data_dj(in_path, user_group=group, silent=silent, **kwargs)
return import_data_dj(in_path, group=group, silent=silent, **kwargs)
else:
raise Exception("Unknown backend: {}".format(configuration.PROFILE.database_backend))


def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,
def import_data_dj(in_path, group=None, ignore_unknown_nodes=False,
extras_mode_existing='kcl', extras_mode_new='import',
comment_mode='newest', silent=False):
"""
@@ -580,6 +581,7 @@ def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,
detect the compression format (zip, tar.gz, tar.bz2, ...) and calls the
correct function.
:param in_path: the path to a file or folder that can be imported in AiiDA
:param group: Group wherein all imported Nodes will be placed.
:param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
The first letter acts on extras that are present in the original node and not present in the imported node.
Can be either:
@@ -625,6 +627,13 @@ def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,
# The returned dictionary with new and existing nodes and links
ret_dict = {}

# Initial check(s)
if group:
if not isinstance(group, Group):
raise TypeError("group must be a Group entity")
elif not group.is_stored:
group.store()

################
# EXTRACT DATA #
################
@@ -688,10 +697,10 @@ def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,
# store them in a reverse table
# I break up the query due to SQLite limitations..
relevant_db_nodes = {}
for group in grouper(999, linked_nodes):
for group_ in grouper(999, linked_nodes):
relevant_db_nodes.update({n.uuid: n for n in
models.DbNode.objects.filter(
uuid__in=group)})
uuid__in=group_)})

db_nodes_uuid = set(relevant_db_nodes.keys())
# ~ dbnode_model = get_class_string(models.DbNode)
@@ -1134,11 +1143,11 @@ def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,
import_groups = data['groups_uuid']
for groupuuid, groupnodes in import_groups.items():
# TODO: cache these to avoid too many queries
group = models.DbGroup.objects.get(uuid=groupuuid)
group_ = models.DbGroup.objects.get(uuid=groupuuid)
nodes_to_store = [dbnode_reverse_mappings[node_uuid]
for node_uuid in groupnodes]
if nodes_to_store:
group.dbnodes.add(*nodes_to_store)
group_.dbnodes.add(*nodes_to_store)

######################################################
# Put everything in a specific group
@@ -1157,14 +1166,10 @@ def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,

# So that we do not create empty groups
if pks_for_group:
# If user specified a group, import all things in it
if user_group:
group = user_group[0]
else:
# Get an unique name for the import group, based on the
# current (local) time
basename = timezone.localtime(timezone.now()).strftime(
"%Y%m%d-%H%M%S")
# If the user specified a group, import everything into it; otherwise create a new one
if not group:
# Get a unique name for the import group, based on the current (local) time
basename = timezone.localtime(timezone.now()).strftime("%Y%m%d-%H%M%S")
counter = 0
created = False
while not created:
@@ -1179,15 +1184,14 @@ def import_data_dj(in_path, user_group=None, ignore_unknown_nodes=False,
counter += 1

# Add all the nodes to the new group
# TODO: decide if we want to return the group label
nodes = [entry[0] for entry in QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}).all()]
group.add_nodes(nodes)

if not silent:
print("IMPORTED NODES GROUPED IN IMPORT GROUP NAMED '{}'".format(group.label))
print("IMPORTED NODES ARE GROUPED IN THE GROUP LABELED '{}'".format(group.label))
else:
if not silent:
print("NO DBNODES TO IMPORT, SO NO GROUP CREATED")
print("NO NODES TO IMPORT, SO NO GROUP CREATED, IF IT DID NOT ALREADY EXIST")

if not silent:
print("*** WARNING: MISSING EXISTING UUID CHECKS!!")
@@ -1213,7 +1217,7 @@ def validate_uuid(given_uuid):
return str(parsed_uuid) == given_uuid


def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
def import_data_sqla(in_path, group=None, ignore_unknown_nodes=False,
extras_mode_existing='kcl', extras_mode_new='import',
comment_mode='newest', silent=False):
"""
@@ -1222,6 +1226,7 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
detect the compression format (zip, tar.gz, tar.bz2, ...) and calls the
correct function.
:param in_path: the path to a file or folder that can be imported in AiiDA
:param group: Group wherein all imported Nodes will be placed.
:param extras_mode_existing: 3 letter code that will identify what to do with the extras import.
The first letter acts on extras that are present in the original node and not present in the imported node.
Can be either:
@@ -1266,6 +1271,13 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
# The returned dictionary with new and existing nodes and links
ret_dict = {}

# Initial check(s)
if group:
if not isinstance(group, Group):
raise TypeError("group must be a Group entity")
elif not group.is_stored:
group.store()

################
# EXTRACT DATA #
################
@@ -1857,14 +1869,14 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
# # TODO: cache these to avoid too many queries
qb_group = QueryBuilder().append(
Group, filters={'uuid': {'==': groupuuid}})
group = qb_group.first()[0]
group_ = qb_group.first()[0]
nodes_ids_to_add = [dbnode_reverse_mappings[node_uuid]
for node_uuid in groupnodes]
qb_nodes = QueryBuilder().append(
Node, filters={'id': {'in': nodes_ids_to_add}})
# Adding nodes to group avoiding the SQLA ORM to increase speed
nodes_to_add = [n[0].backend_entity for n in qb_nodes.all()]
group.backend_entity.add_nodes(nodes_to_add, skip_orm=True)
group_.backend_entity.add_nodes(nodes_to_add, skip_orm=True)

######################################################
# Put everything in a specific group
@@ -1880,13 +1892,9 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
# So that we do not create empty groups
if pks_for_group:
# If the user specified a group, import everything into it; otherwise create a new one
if user_group:
group = user_group[0]
else:
# Get an unique name for the import group, based on the
# current (local) time
basename = timezone.localtime(timezone.now()).strftime(
"%Y%m%d-%H%M%S")
if not group:
# Get a unique name for the import group, based on the current (local) time
basename = timezone.localtime(timezone.now()).strftime("%Y%m%d-%H%M%S")
counter = 0
created = False
while not created:
@@ -1895,24 +1903,25 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
else:
group_label = "{}_{}".format(basename, counter)

group = Group(label=group_label,
type_string=IMPORTGROUP_TYPE)
group = Group(label=group_label, type_string=IMPORTGROUP_TYPE)
from aiida.backends.sqlalchemy.models.group import DbGroup
if session.query(DbGroup).filter(
DbGroup.label == group.backend_entity._dbmodel.label).count() == 0:
session.add(group.backend_entity._dbmodel)
DbGroup.label == group.backend_entity._dbmodel.label).count() == 0: # pylint: disable=protected-access
session.add(group.backend_entity._dbmodel) # pylint: disable=protected-access
created = True
else:
counter += 1

# Adding nodes to group avoiding the SQLA ORM to increase speed
nodes = [entry[0].backend_entity for entry in QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}).all()]
nodes = [entry[0].backend_entity for entry in QueryBuilder().append(
Node, filters={'id': {'in': pks_for_group}}).all()]
group.backend_entity.add_nodes(nodes, skip_orm=True)

if not silent:
print("IMPORTED NODES GROUPED IN IMPORT GROUP NAMED '{}'".format(group.label))
print("IMPORTED NODES ARE GROUPED IN THE GROUP LABELED '{}'".format(group.label))
else:
if not silent:
print("NO DBNODES TO IMPORT, SO NO GROUP CREATED")
print("NO DBNODES TO IMPORT, SO NO GROUP CREATED, IF IT DID ALREADY EXIST")

if not silent:
print("COMMITTING EVERYTHING...")
