diff --git a/airflow-core/src/airflow/utils/db_manager.py b/airflow-core/src/airflow/utils/db_manager.py index 8c4eb11a0d4fe..57173341e88ba 100644 --- a/airflow-core/src/airflow/utils/db_manager.py +++ b/airflow-core/src/airflow/utils/db_manager.py @@ -142,6 +142,37 @@ def create_db_from_orm(self): command.stamp(config, "head") self.log.info("%s tables have been created from the ORM", self.__class__.__name__) + def _has_existing_manager_tables(self) -> bool: + """Return whether any table managed by this DB manager already exists.""" + inspector = inspect(self.session.get_bind()) + table_names_by_schema: dict[str | None, set[str]] = {} + for table in self.metadata.tables.values(): + table_names_by_schema.setdefault(table.schema, set()).add(table.name) + + for schema, table_names in table_names_by_schema.items(): + existing_table_names = set(inspector.get_table_names(schema=schema)) + if table_names.intersection(existing_table_names): + return True + return False + + def _get_base_revision(self, config=None) -> str: + """Return the first/base Alembic revision for this DB manager.""" + script = self.get_script_object(config) + for revision in script.walk_revisions(): + if revision.down_revision is None: + return revision.revision + raise RuntimeError(f"No base revision found for {self.__class__.__name__}") + + def _stamp_base_revision(self, config) -> None: + """Stamp the database to this DB manager's base Alembic revision.""" + base_revision = self._get_base_revision(config) + self.log.info( + "%s tables already exist without an Alembic version; stamping base revision %s before upgrade", + self.__class__.__name__, + base_revision, + ) + command.stamp(config, base_revision) + def drop_tables(self, connection): if not self.supports_table_dropping: return @@ -189,10 +220,15 @@ def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False, u self._release_metadata_locks_if_needed() if not current_revision and not to_revision and not use_migration_files and not show_sql_only: - self.create_db_from_orm() - return + if self._has_existing_manager_tables(): + config = self.get_alembic_config() + self._stamp_base_revision(config) + else: + self.create_db_from_orm() + return + else: + config = self.get_alembic_config() - config = self.get_alembic_config() command.upgrade(config, revision=to_revision or "heads", sql=show_sql_only) self.log.info("Migrated the %s database", self.__class__.__name__) diff --git a/airflow-core/tests/unit/utils/test_db_manager.py b/airflow-core/tests/unit/utils/test_db_manager.py index fa57fd9554ad6..1f8f3e7423af7 100644 --- a/airflow-core/tests/unit/utils/test_db_manager.py +++ b/airflow-core/tests/unit/utils/test_db_manager.py @@ -22,6 +22,7 @@ from unittest import mock import pytest +from sqlalchemy import Column, Integer, MetaData, Table from airflow.models import Base from airflow.utils.db_manager import BaseDBManager, RunDBManager @@ -50,6 +51,17 @@ def downgrade(self, to_revision, from_revision=None, show_sql_only=False): alembic_command.downgrade(config, revision=to_revision, sql=show_sql_only) +legacy_metadata = MetaData() +Table("external_legacy_table", legacy_metadata, Column("id", Integer, primary_key=True)) + + +class LegacyTablesDBManager(BaseDBManager): + metadata = legacy_metadata + version_table_name = "legacy_alembic_version" + migration_dir = "legacy_migration_dir" + alembic_file = "legacy_alembic.ini" + + class LegacySignatureExternalManager: initdb_calls = 0 upgradedb_calls = 0 @@ -154,15 +166,51 @@ def test_upgrade(self, mock_current_revision, mock_alembic_cmd, mock_alembic_con assert "Upgrading the MockDBManager database" in caplog.text @mock.patch.object(BaseDBManager, "create_db_from_orm") + @mock.patch.object(BaseDBManager, "_has_existing_manager_tables", return_value=False) @mock.patch.object(BaseDBManager, "get_current_revision") def test_upgrade_empty_db_without_migration_files_uses_create_db_from_orm( - self, mock_current_revision, mock_create_db_from_orm, session + self, mock_current_revision, mock_has_existing_manager_tables, mock_create_db_from_orm, session ): mock_current_revision.return_value = None manager = MockDBManager(session) manager.upgradedb() + mock_has_existing_manager_tables.assert_called_once() mock_create_db_from_orm.assert_called_once() + @mock.patch.object(BaseDBManager, "get_current_revision", return_value=None) + @mock.patch.object(BaseDBManager, "create_db_from_orm") + @mock.patch.object(BaseDBManager, "get_alembic_config") + @mock.patch.object(BaseDBManager, "get_script_object") + @mock.patch("airflow.utils.db_manager.inspect") + @mock.patch("alembic.command.stamp") + @mock.patch("alembic.command.upgrade") + def test_upgrade_with_existing_manager_tables_without_version_stamps_base_then_runs_migrations( + self, + mock_upgrade, + mock_stamp, + mock_inspect, + mock_get_script_object, + mock_get_alembic_config, + mock_create_db_from_orm, + mock_get_current_revision, + session, + ): + config = object() + mock_get_alembic_config.return_value = config + base_revision = mock.Mock(revision="base-revision", down_revision=None) + mock_get_script_object.return_value.walk_revisions.return_value = [ + mock.Mock(revision="head-revision", down_revision="base-revision"), + base_revision, + ] + mock_inspect.return_value.get_table_names.return_value = ["external_legacy_table"] + + manager = LegacyTablesDBManager(session) + manager.upgradedb() + + mock_create_db_from_orm.assert_not_called() + mock_stamp.assert_called_once_with(config, "base-revision") + mock_upgrade.assert_called_once_with(config, revision="heads", sql=False) + @mock.patch.object(BaseDBManager, "get_script_object") @mock.patch.object(BaseDBManager, "get_current_revision") def test_check_migration(self, mock_script_obj, mock_current_revision, session):