Skip to content

Commit

Permalink
PsqlDosBackend: Fix Node.store excepting when inside a transaction (#6125)
Browse files Browse the repository at this point in the history

Calling `Node.store` with the `PsqlDosBackend` would except whenever
inside a transaction, for example, when iterating over a `QueryBuilder`
result, which opens a transaction.

The reason is that the node implementation of the `PsqlDosBackend`, the
`SqlaNode.store` method calls `commit` on the session. This closes the
current transaction, and so when it is then used again, for example in
the next iteration of the builder results, an exception is raised by
sqlalchemy complaining that the transaction was closed.

The solution is that `SqlaNode.store` should only commit if it is not
inside a nested transaction, otherwise it should simply flush the
addition of the node to the session such that automatically generated
primary keys are populated.

A similar problem was addressed in the `add_nodes` and `remove_nodes`
methods of the `SqlaGroup` class. These would also call `commit` at the
end, regardless of whether they are called within an open transaction.
  • Loading branch information
sphuber committed Sep 22, 2023
1 parent 34be3b6 commit 624dcd9
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 32 deletions.
2 changes: 0 additions & 2 deletions aiida/orm/implementation/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,11 @@ def add_incoming(self, source: 'BackendNode', link_type, link_label):
def store( # pylint: disable=arguments-differ
self: BackendNodeType,
links: Optional[Sequence['LinkTriple']] = None,
with_transaction: bool = True,
clean: bool = True
) -> BackendNodeType:
"""Store the node in the database.
:param links: optional links to add before storing
:param with_transaction: if False, do not use a transaction because the caller will already have opened one.
:param clean: boolean, if True, will clean the attributes and extras before attempting to store
"""

Expand Down
25 changes: 10 additions & 15 deletions aiida/orm/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,10 @@ def mtime(self) -> datetime.datetime:
"""
return self.backend_entity.mtime

def store_all(self, with_transaction: bool = True) -> 'Node':
def store_all(self) -> 'Node':
"""Store the node, together with all input links.
Unstored nodes from cached incoming links will also be stored.
:parameter with_transaction: if False, do not use a transaction because the caller will already have opened one.
"""
if self.is_stored:
raise exceptions.ModificationNotAllowed(f'Node<{self.pk}> is already stored')
Expand All @@ -440,20 +438,18 @@ def store_all(self, with_transaction: bool = True) -> 'Node':

for link_triple in self.base.links.incoming_cache:
if not link_triple.node.is_stored:
link_triple.node.store(with_transaction=with_transaction)
link_triple.node.store()

return self.store(with_transaction)
return self.store()

def store(self, with_transaction: bool = True) -> 'Node': # pylint: disable=arguments-differ
def store(self) -> 'Node': # pylint: disable=arguments-differ
"""Store the node in the database while saving its attributes and repository directory.
After being called attributes cannot be changed anymore! Instead, extras can be changed only AFTER calling
this store() function.
:note: After successful storage, those links that are in the cache, and for which also the parent node is
already stored, will be automatically stored. The others will remain unstored.
:parameter with_transaction: if False, do not use a transaction because the caller will already have opened one.
"""
from aiida.manage.caching import get_use_cache

Expand All @@ -477,26 +473,25 @@ def store(self, with_transaction: bool = True) -> 'Node': # pylint: disable=arg
same_node = self.base.caching._get_same_node() if use_cache else None # pylint: disable=protected-access

if same_node is not None:
self._store_from_cache(same_node, with_transaction=with_transaction)
self._store_from_cache(same_node)
else:
self._store(with_transaction=with_transaction, clean=True)
self._store(clean=True)

if self.backend.autogroup.is_to_be_grouped(self):
group = self.backend.autogroup.get_or_create_group()
group.add_nodes(self)

return self

def _store(self, with_transaction: bool = True, clean: bool = True) -> 'Node':
def _store(self, clean: bool = True) -> 'Node':
"""Store the node in the database while saving its attributes and repository directory.
:param with_transaction: if False, do not use a transaction because the caller will already have opened one.
:param clean: boolean, if True, will clean the attributes and extras before attempting to store
"""
self.base.repository._store() # pylint: disable=protected-access

links = self.base.links.incoming_cache
self._backend_entity.store(links, with_transaction=with_transaction, clean=clean)
self._backend_entity.store(links, clean=clean)

self.base.links.incoming_cache = []
self.base.caching.rehash()
Expand All @@ -514,7 +509,7 @@ def _verify_are_parents_stored(self) -> None:
f'Cannot store because source node of link triple {link_triple} is not stored'
)

def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None:
def _store_from_cache(self, cache_node: 'Node') -> None:
"""Store this node from an existing cache node.
.. note::
Expand Down Expand Up @@ -542,7 +537,7 @@ def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None:
if key != Sealable.SEALED_KEY:
self.base.attributes.set(key, value)

self._store(with_transaction=with_transaction, clean=False)
self._store(clean=False)
self._add_outputs_from_cache(cache_node)
self.base.extras.set('_aiida_cached_from', cache_node.uuid)

Expand Down
15 changes: 5 additions & 10 deletions aiida/storage/psql_dos/orm/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def label(self, label):
try:
self.model.save()
except Exception:
raise UniquenessError(f'a group of the same type with the label {label} already exists') \
from Exception
raise UniquenessError(f'a group of the same type with the label {label} already exists') from Exception

@property
def description(self):
Expand Down Expand Up @@ -104,12 +103,6 @@ def pk(self):
def uuid(self):
return str(self.model.uuid)

def __int__(self):
if not self.is_stored:
return None

return self._dbnode.id # pylint: disable=no-member

@property
def is_stored(self):
return self.pk is not None
Expand Down Expand Up @@ -220,7 +213,8 @@ def check_node(given_node):
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
session.commit()
if not session.in_nested_transaction():
session.commit()

def remove_nodes(self, nodes, **kwargs):
"""Remove a node or a set of nodes from the group.
Expand Down Expand Up @@ -268,7 +262,8 @@ def check_node(node):
statement = table.delete().where(clause)
session.execute(statement)

session.commit()
if not session.in_nested_transaction():
session.commit()


class SqlaGroupCollection(BackendGroupCollection):
Expand Down
7 changes: 5 additions & 2 deletions aiida/storage/psql_dos/orm/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def clean_values(self):
self.model.attributes = clean_value(self.model.attributes)
self.model.extras = clean_value(self.model.extras)

def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ
def store(self, links=None, clean=True): # pylint: disable=arguments-differ
session = self.backend.get_session()

if clean:
Expand All @@ -223,12 +223,15 @@ def store(self, links=None, with_transaction=True, clean=True): # pylint: disab
for link_triple in links:
self._add_link(*link_triple)

if with_transaction:
if not session.in_nested_transaction():
try:
session.commit()
except SQLAlchemyError:
session.rollback()
raise
else:
# Make sure the new addition is flushed to, e.g., populate automatic primary keys.
session.flush()

return self

Expand Down
2 changes: 2 additions & 0 deletions aiida/storage/psql_dos/orm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def save(self):
self.session.add(self._model)
if not self._in_transaction():
self.session.commit()
else:
self.session.flush()
except IntegrityError as exception:
self.session.rollback()
raise exceptions.IntegrityError(str(exception))
Expand Down
2 changes: 1 addition & 1 deletion tests/orm/nodes/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ def test_store_from_cache(self, tmp_path):
data.store()

clone = data.clone()
clone._store_from_cache(data, with_transaction=True) # pylint: disable=protected-access
clone._store_from_cache(data) # pylint: disable=protected-access

assert clone.is_stored
assert clone.base.caching.get_cache_source() == data.uuid
Expand Down
25 changes: 23 additions & 2 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,10 +1495,31 @@ def test_iterall_with_mutation(self):
assert orm.load_node(pk).get_extra('key') == 'value'

@pytest.mark.usefixtures('aiida_profile_clean')
@pytest.mark.skip('enable when https://github.com/aiidateam/aiida-core/issues/5802 is fixed')
def test_iterall_with_store(self):
"""Test that nodes can be stored while being iterated using ``QueryBuilder.iterall``.
This is a regression test for https://github.com/aiidateam/aiida-core/issues/5802 .
"""
count = 10
pks = []
pk_clones = []

for _ in range(count):
node = orm.Int().store()
pks.append(node.pk)

# Ensure that batch size is smaller than the total rows yielded
for [node] in orm.QueryBuilder().append(orm.Int).iterall(batch_size=2):
clone = orm.Int(node.value).store()
pk_clones.append(clone.pk)

for pk, pk_clone in zip(pks, sorted(pk_clones)):
assert orm.load_node(pk) == orm.load_node(pk_clone)

@pytest.mark.usefixtures('aiida_profile_clean')
def test_iterall_with_store_group(self):
"""Test that nodes can be stored and added to groups while being iterated using ``QueryBuilder.iterall``.
This is a regression test for https://github.com/aiidateam/aiida-core/issues/5802 .
"""
count = 10
Expand All @@ -1510,7 +1531,7 @@ def test_iterall_with_store(self):
pks.append(node.pk)

# Ensure that batch size is smaller than the total rows yielded
for [node] in orm.QueryBuilder().append(orm.Data).iterall(batch_size=2):
for [node] in orm.QueryBuilder().append(orm.Int).iterall(batch_size=2):
clone = copy.deepcopy(node)
clone.store()
pks_clone.append((clone.value, clone.pk))
Expand Down

0 comments on commit 624dcd9

Please sign in to comment.