Skip to content

Commit

Permalink
PsqlDos: Add entry_point_string argument to drop_hashes
Browse files Browse the repository at this point in the history
The `drop_hashes` function is a utility used by migrations to drop
the hashes of nodes. This is used when the logic to compute the node
hash is changed, and so existing hashes are no longer valid. The
migration will drop all hashes and emit a warning, suggesting that the
user run `verdi node rehash` to recompute the hashes.

In certain migrations, however, only the hashes of certain node types
may need to be dropped. Here the `entry_point_string` argument is added,
which allows filtering the target set of nodes. The entry point
corresponding to the entry point string is loaded, and from it the relevant
`node_type` is computed. This is used to create a `WHERE` clause in the
query. The warning is also made dynamic to include the correct option
for `verdi node rehash` to only recompute the hash for the target subset
of nodes.
  • Loading branch information
sphuber committed May 15, 2023
1 parent c447a1a commit c7a36fa
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions aiida/storage/psql_dos/migrations/utils/integrity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
###########################################################################
# pylint: disable=invalid-name
"""Methods to validate the database integrity and fix violations."""
from __future__ import annotations

from aiida.common.log import AIIDA_LOGGER

LOGGER = AIIDA_LOGGER.getChild(__file__)

WARNING_BORDER = '*' * 120

# These are all the entry points from the `aiida.calculations` category as registered with the AiiDA registry
Expand Down Expand Up @@ -160,7 +166,6 @@ def write_database_integrity_violation(results, headers, reason_message, action_

from tabulate import tabulate

from aiida.cmdline.utils import echo
from aiida.manage import configuration

global_profile = configuration.get_profile()
Expand All @@ -171,8 +176,7 @@ def write_database_integrity_violation(results, headers, reason_message, action_
action_message = 'nothing'

with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle:
echo.echo('')
echo.echo_warning(
LOGGER.warning(
'\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n'
'Performed action: {}\nViolators written to: {}\n{}\n'.format(
WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER
Expand All @@ -186,21 +190,45 @@ def write_database_integrity_violation(results, headers, reason_message, action_
handle.write(tabulate(results, headers))


def drop_hashes(conn, hash_extra_key: str, entry_point_string: str | None = None) -> None:
    """Drop hashes of nodes.

    A warning is logged only if the database actually contains nodes that match the (optional) filter.

    :param conn: The database connection on which the SQL statements are executed.
    :param hash_extra_key: The key in the extras used to store the hash at the time of this migration.
    :param entry_point_string: Optional entry point string of a node type to narrow the subset of nodes to reset. The
        value should be a complete entry point string, e.g., ``aiida.node:process.calculation.calcjob`` to drop the hash
        of all ``CalcJobNode`` rows.
    """
    # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed
    # pylint: disable=no-name-in-module,import-error
    from sqlalchemy.sql import text

    from aiida.orm.utils.node import get_type_string_from_class
    from aiida.plugins import load_entry_point_from_string

    node_type = None

    if entry_point_string is not None:
        entry_point = load_entry_point_from_string(entry_point_string)
        node_type = get_type_string_from_class(entry_point.__module__, entry_point.__name__)

    # Both statements share the same optional filter, so the base SQL is built once and the filter appended if needed.
    sql_count = 'SELECT count(*) FROM db_dbnode'
    sql_update = f"UPDATE db_dbnode SET extras = extras #- '{{{hash_extra_key}}}'::text[]"

    if node_type:
        # Pass the node type as a bound parameter instead of interpolating it into the SQL string.
        sql_filter = ' WHERE node_type = :node_type'
        statement_count = text(f'{sql_count}{sql_filter};').bindparams(node_type=node_type)
        statement_update = text(f'{sql_update}{sql_filter};').bindparams(node_type=node_type)
    else:
        statement_count = text(f'{sql_count};')
        statement_update = text(f'{sql_update};')

    node_count = conn.execute(statement_count).scalar()

    if node_count > 0:
        # The command to recompute hashes is ``verdi node rehash``; restrict it to the same subset when filtering.
        if entry_point_string:
            msg = f'Invalidating the hashes of certain nodes. Please run `verdi node rehash -e {entry_point_string}`.'
        else:
            msg = 'Invalidating the hashes of all nodes. Please run `verdi node rehash`.'
        LOGGER.warning(msg)

    conn.execute(statement_update)

0 comments on commit c7a36fa

Please sign in to comment.