Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PsqlDosMigrator: Remove hardcoding of table name in database reset #5781

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from disk_objectstore import Container
from sqlalchemy import table
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError
Expand Down Expand Up @@ -170,18 +169,21 @@ def _clear(self) -> None:

super()._clear()

session = self.get_session()

with self.migrator_context(self._profile) as migrator:

with self.transaction():
for table_name in (
'db_dbgroup_dbnodes', 'db_dbgroup', 'db_dblink', 'db_dbnode', 'db_dblog', 'db_dbauthinfo',
'db_dbuser', 'db_dbcomputer'
):
session.execute(table(table_name).delete())
# First clear the contents of the database
with self.transaction() as session:

# Close the session otherwise the ``delete_tables`` call will hang as there will be an open connection
# to the PostgreSQL server and it will block the deletion and the command will hang.
Copy link
Member

@ltalirz ltalirz Nov 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this measure prevents connections opened by this particular process from interfering, but I guess there could be other connections open to the DB (daemon / REST API / verdi shell / ...)?

Is there a way we could detect this and exit with an error message instead of hanging?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure to be honest. Not sure if sqlalchemy provides an API to get all open connection to a database, even those that are managed by it. I doubt it. We would probably have to go straight to psycopg and directly execute a postgres command. Anyway, for now, this method is only being called during unit testing, so it is not that likely.

self.get_session().close()
exclude_tables = [migrator.alembic_version_tbl_name, 'db_dbsetting']
migrator.delete_all_tables(exclude_tables=exclude_tables)

# Clear out all references to database model instances which are now invalid.
session.expunge_all()

# Now reset and reinitialise the repository
migrator.reset_repository()
migrator.initialise_repository()
repository_uuid = migrator.get_repository_uuid()
Expand Down
34 changes: 26 additions & 8 deletions aiida/storage/psql_dos/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from alembic.runtime.migration import MigrationContext, MigrationInfo
from alembic.script import ScriptDirectory
from disk_objectstore import Container
from sqlalchemy import String, Table, column, desc, insert, inspect, select, table
from sqlalchemy import MetaData, String, Table, column, desc, insert, inspect, select, table
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -272,13 +272,7 @@ def reset_database(self) -> None:

This will also destroy the settings table and so in order to use it again, it will have to be reinitialised.
"""
if inspect(self.connection).has_table(self.alembic_version_tbl_name):
for table_name in (
'db_dbgroup_dbnodes', 'db_dbgroup', 'db_dblink', 'db_dbnode', 'db_dblog', 'db_dbauthinfo', 'db_dbuser',
'db_dbcomputer', 'db_dbsetting'
):
self.connection.execute(table(table_name).delete())
self.connection.commit()
self.delete_all_tables(exclude_tables=[self.alembic_version_tbl_name])

def initialise_repository(self) -> None:
"""Initialise the repository."""
Expand Down Expand Up @@ -312,6 +306,30 @@ def initialise_database(self) -> None:
context.stamp(context.script, 'main@head')
self.connection.commit()

def delete_all_tables(self, *, exclude_tables: list[str] | None = None) -> None:
"""Delete all tables of the current database schema.

The tables are determined dynamically through reflection of the current schema version. Any other tables in the
database that are not part of the schema should remain unaffected.

:param exclude_tables: Optional list of table names that should not be deleted.
"""
exclude_tables = exclude_tables or []

if inspect(self.connection).has_table(self.alembic_version_tbl_name):

metadata = MetaData()
metadata.reflect(bind=self.connection)

# The ``sorted_tables`` property returns the tables sorted by their foreign-key dependencies, with those
# that are dependent on others first. Iterate over the list in reverse to ensure that the tables with
# the independent rows are deleted first.
for schema_table in reversed(metadata.sorted_tables):
if schema_table.name in exclude_tables:
continue
self.connection.execute(schema_table.delete())
self.connection.commit()

def migrate(self) -> None:
"""Migrate the storage for this profile to the head version.

Expand Down