Skip to content

Commit

Permalink
ORM: Replace .collection(backend) with .get_collection(backend)
Browse files Browse the repository at this point in the history
The classproperty will try to load the default storage backend first,
before recreating the collection with the specified backend. Not only is
this inefficient as the collection is recreated if the `backend` is not
the current default one, but it can also fail in situations where there
is no default profile is available and the caller wants to directly
specify the backend.
  • Loading branch information
sphuber committed Sep 5, 2023
1 parent 305f1db commit bac2152
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 27 deletions.
12 changes: 6 additions & 6 deletions aiida/orm/computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def get_authinfo(self, user: 'User') -> 'AuthInfo':
from . import authinfos

try:
authinfo = authinfos.AuthInfo.collection(self.backend).get(dbcomputer_id=self.pk, aiidauser_id=user.pk)
authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer_id=self.pk, aiidauser_id=user.pk)
except exceptions.NotExistent as exc:
raise exceptions.NotExistent(
f'Computer `{self.label}` (ID={self.pk}) not configured for user `{user.get_short_name()}` '
Expand All @@ -586,7 +586,7 @@ def is_configured(self) -> bool:
:return: Boolean, ``True`` if the computer is configured for the current default user, ``False`` otherwise.
"""
return self.is_user_configured(users.User.collection(self.backend).get_default())
return self.is_user_configured(users.User.get_collection(self.backend).get_default())

def is_user_configured(self, user: 'User') -> bool:
"""
Expand Down Expand Up @@ -636,8 +636,8 @@ def get_transport(self, user: Optional['User'] = None) -> 'Transport':
"""
from . import authinfos # pylint: disable=cyclic-import

user = user or users.User.collection(self.backend).get_default()
authinfo = authinfos.AuthInfo.collection(self.backend).get(dbcomputer=self, aiidauser=user)
user = user or users.User.get_collection(self.backend).get_default()
authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer=self, aiidauser=user)
return authinfo.get_transport()

def get_transport_class(self) -> Type['Transport']:
Expand Down Expand Up @@ -670,7 +670,7 @@ def configure(self, user: Optional['User'] = None, **kwargs: Any) -> 'AuthInfo':
from . import authinfos

transport_cls = self.get_transport_class()
user = user or users.User.collection(self.backend).get_default()
user = user or users.User.get_collection(self.backend).get_default()
valid_keys = set(transport_cls.get_valid_auth_params())

if not set(kwargs.keys()).issubset(valid_keys):
Expand All @@ -696,7 +696,7 @@ def get_configuration(self, user: Optional['User'] = None) -> Dict[str, Any]:
:param user: the user to to get the configuration for, otherwise default user
"""
user = user or users.User.collection(self.backend).get_default()
user = user or users.User.get_collection(self.backend).get_default()

try:
authinfo = self.get_authinfo(user)
Expand Down
16 changes: 8 additions & 8 deletions aiida/orm/nodes/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def add(self, content: str, user: t.Optional[User] = None) -> Comment:
:param user: the user to associate with the comment, will use default if not supplied
:return: the newly created comment
"""
user = user or User.collection(self._node.backend).get_default()
user = user or User.get_collection(self._node.backend).get_default()
return Comment(node=self._node, user=user, content=content).store()

def get(self, identifier: int) -> Comment:
Expand All @@ -33,17 +33,17 @@ def get(self, identifier: int) -> Comment:
:raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment
:return: the comment
"""
return Comment.collection(self._node.backend).get(dbnode_id=self._node.pk, id=identifier)
return Comment.get_collection(self._node.backend).get(dbnode_id=self._node.pk, id=identifier)

def all(self) -> list[Comment]:
"""Return a sorted list of comments for this node.
:return: the list of comments, sorted by pk
"""
return Comment.collection(self._node.backend
).find(filters={'dbnode_id': self._node.pk}, order_by=[{
'id': 'asc'
}])
return Comment.get_collection(self._node.backend
).find(filters={'dbnode_id': self._node.pk}, order_by=[{
'id': 'asc'
}])

def update(self, identifier: int, content: str) -> None:
"""Update the content of an existing comment.
Expand All @@ -53,12 +53,12 @@ def update(self, identifier: int, content: str) -> None:
:raise aiida.common.NotExistent: if the comment with the given id does not exist
:raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment
"""
comment = Comment.collection(self._node.backend).get(dbnode_id=self._node.pk, id=identifier)
comment = Comment.get_collection(self._node.backend).get(dbnode_id=self._node.pk, id=identifier)
comment.set_content(content)

def remove(self, identifier: int) -> None:
"""Delete an existing comment.
:param identifier: the comment pk
"""
Comment.collection(self._node.backend).delete(identifier)
Comment.get_collection(self._node.backend).delete(identifier)
7 changes: 6 additions & 1 deletion aiida/orm/nodes/data/array/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,9 +1800,14 @@ def get_bands_and_parents_structure(args, backend=None):
from aiida import orm
from aiida.common import timezone

if backend:
user = orm.User.get_collection(backend).get_default()
else:
user = orm.User.collection.get_default()

q_build = orm.QueryBuilder(backend=backend)
if args.all_users is False:
q_build.append(orm.User, tag='creator', filters={'email': orm.User.collection.get_default().email})
q_build.append(orm.User, tag='creator', filters={'email': user.email})
else:
q_build.append(orm.User, tag='creator')

Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/nodes/data/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,4 @@ def _validate(self):
raise ValidationError('Remote computer not set.')

def get_authinfo(self):
return AuthInfo.collection(self.backend).get(dbcomputer=self.computer, aiidauser=self.user)
return AuthInfo.get_collection(self.backend).get(dbcomputer=self.computer, aiidauser=self.user)
12 changes: 8 additions & 4 deletions aiida/orm/nodes/data/upf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T
:param stop_if_existing: if True, check for the md5 of the files and, if the file already exists in the DB, raises a
MultipleObjectsError. If False, simply adds the existing UPFData node to the group.
"""
# pylint: disable=too-many-locals,too-many-branches
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
import os

from aiida import orm
Expand All @@ -101,10 +101,14 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T

nfiles = len(filenames)

automatic_user = orm.User.collection.get_default()
group, group_created = orm.UpfFamily.collection.get_or_create(label=group_label, user=automatic_user)
if backend:
default_user = orm.User.get_collection(backend).get_default()
group, group_created = orm.UpfFamily.get_collection(backend).get_or_create(label=group_label, user=default_user)
else:
default_user = orm.User.collection.get_default()
group, group_created = orm.UpfFamily.collection.get_or_create(label=group_label, user=default_user)

if group.user.email != automatic_user.email:
if group.user.email != default_user.email:
raise UniquenessError(
'There is already a UpfFamily group with label {}'
', but it belongs to user {}, therefore you '
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/utils/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def emit(self, record):
try:
try:
backend = record.__dict__.pop('backend')
orm.Log.collection(backend).create_entry_from_record(record)
orm.Log.get_collection(backend).create_entry_from_record(record)
except KeyError:
# The backend should be set. We silently absorb this error
pass
Expand Down
6 changes: 5 additions & 1 deletion aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def get_calcjob_remote_paths( # pylint: disable=too-many-locals
:param only_not_cleaned: only include calcjobs whose workdir have not been cleaned
:return: mapping of computer uuid and list of remote folder
"""
# pylint: disable=too-many-branches
from datetime import timedelta

from aiida import orm
Expand All @@ -74,7 +75,10 @@ def get_calcjob_remote_paths( # pylint: disable=too-many-locals
filters_remote = {}

if user is None:
user = orm.User.collection.get_default()
if backend:
user = orm.User.get_collection(backend).get_default()
else:
user = orm.User.collection.get_default()

if computers is not None:
filters_computer['id'] = {'in': [computer.pk for computer in computers]}
Expand Down
2 changes: 1 addition & 1 deletion aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def get_unreferenced_keyset(self, check_consistency: bool = True) -> Set[str]:
repository = self.get_repository()

keyset_repository = set(repository.list_objects())
keyset_database = set(orm.Node.collection(self).iter_repo_keys())
keyset_database = set(orm.Node.get_collection(self).iter_repo_keys())

if check_consistency:
keyset_missing = keyset_database - keyset_repository
Expand Down
4 changes: 2 additions & 2 deletions aiida/tools/archive/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def create_archive(
if test_run:
EXPORT_LOGGER.report('Test Run: Stopping before archive creation')
keys = set(
orm.Node.collection(backend).iter_repo_keys(
orm.Node.get_collection(backend).iter_repo_keys(
filters={'id': {
'in': list(entity_ids[EntityTypes.NODE])
}}, batch_size=batch_size
Expand Down Expand Up @@ -593,7 +593,7 @@ def _stream_repo_files(
) -> None:
"""Collect all repository object keys from the nodes, then stream the files to the archive."""
keys = set(
orm.Node.collection(backend).iter_repo_keys(filters={'id': {
orm.Node.get_collection(backend).iter_repo_keys(filters={'id': {
'in': list(node_ids)
}}, batch_size=batch_size)
)
Expand Down
4 changes: 2 additions & 2 deletions aiida/tools/visualization/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,9 @@ def _load_node(self, node: int | str | orm.Node) -> orm.Node:
:param node: node or node pk/uuid
"""
if isinstance(node, int):
return orm.Node.collection(self._backend).get(pk=node)
return orm.Node.get_collection(self._backend).get(pk=node)
if isinstance(node, str):
return orm.Node.collection(self._backend).get(uuid=node)
return orm.Node.get_collection(self._backend).get(uuid=node)
return node

def add_node(
Expand Down

0 comments on commit bac2152

Please sign in to comment.