Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions airflow-core/src/airflow/utils/db_manager.py
Comment thread
anmolxlight marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,37 @@ def create_db_from_orm(self):
command.stamp(config, "head")
self.log.info("%s tables have been created from the ORM", self.__class__.__name__)

def _has_existing_manager_tables(self) -> bool:
"""Return whether any table managed by this DB manager already exists."""
inspector = inspect(self.session.get_bind())
table_names_by_schema: dict[str | None, set[str]] = {}
for table in self.metadata.tables.values():
table_names_by_schema.setdefault(table.schema, set()).add(table.name)

for schema, table_names in table_names_by_schema.items():
existing_table_names = set(inspector.get_table_names(schema=schema))
if table_names.intersection(existing_table_names):
return True
return False

def _get_base_revision(self, config=None) -> str:
"""Return the first/base Alembic revision for this DB manager."""
script = self.get_script_object(config)
for revision in script.walk_revisions():
if revision.down_revision is None:
return revision.revision
raise RuntimeError(f"No base revision found for {self.__class__.__name__}")

def _stamp_base_revision(self, config) -> None:
"""Stamp the database to this DB manager's base Alembic revision."""
base_revision = self._get_base_revision(config)
self.log.info(
"%s tables already exist without an Alembic version; stamping base revision %s before upgrade",
self.__class__.__name__,
base_revision,
)
command.stamp(config, base_revision)

def drop_tables(self, connection):
if not self.supports_table_dropping:
return
Expand Down Expand Up @@ -189,10 +220,15 @@ def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False, u
self._release_metadata_locks_if_needed()

if not current_revision and not to_revision and not use_migration_files and not show_sql_only:
self.create_db_from_orm()
return
if self._has_existing_manager_tables():
config = self.get_alembic_config()
self._stamp_base_revision(config)
else:
self.create_db_from_orm()
return
Comment thread
jscheffl marked this conversation as resolved.
else:
config = self.get_alembic_config()

config = self.get_alembic_config()
command.upgrade(config, revision=to_revision or "heads", sql=show_sql_only)
self.log.info("Migrated the %s database", self.__class__.__name__)

Expand Down
50 changes: 49 additions & 1 deletion airflow-core/tests/unit/utils/test_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock

import pytest
from sqlalchemy import Column, Integer, MetaData, Table

from airflow.models import Base
from airflow.utils.db_manager import BaseDBManager, RunDBManager
Expand Down Expand Up @@ -50,6 +51,17 @@ def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
alembic_command.downgrade(config, revision=to_revision, sql=show_sql_only)


legacy_metadata = MetaData()
Table("external_legacy_table", legacy_metadata, Column("id", Integer, primary_key=True))


class LegacyTablesDBManager(BaseDBManager):
metadata = legacy_metadata
version_table_name = "legacy_alembic_version"
migration_dir = "legacy_migration_dir"
alembic_file = "legacy_alembic.ini"


class LegacySignatureExternalManager:
initdb_calls = 0
upgradedb_calls = 0
Expand Down Expand Up @@ -154,15 +166,51 @@ def test_upgrade(self, mock_current_revision, mock_alembic_cmd, mock_alembic_con
assert "Upgrading the MockDBManager database" in caplog.text

@mock.patch.object(BaseDBManager, "create_db_from_orm")
@mock.patch.object(BaseDBManager, "_has_existing_manager_tables", return_value=False)
@mock.patch.object(BaseDBManager, "get_current_revision")
def test_upgrade_empty_db_without_migration_files_uses_create_db_from_orm(
self, mock_current_revision, mock_create_db_from_orm, session
self, mock_current_revision, mock_has_existing_manager_tables, mock_create_db_from_orm, session
):
mock_current_revision.return_value = None
manager = MockDBManager(session)
manager.upgradedb()
mock_has_existing_manager_tables.assert_called_once()
mock_create_db_from_orm.assert_called_once()

@mock.patch.object(BaseDBManager, "get_current_revision", return_value=None)
@mock.patch.object(BaseDBManager, "create_db_from_orm")
@mock.patch.object(BaseDBManager, "get_alembic_config")
@mock.patch.object(BaseDBManager, "get_script_object")
@mock.patch("airflow.utils.db_manager.inspect")
@mock.patch("alembic.command.stamp")
@mock.patch("alembic.command.upgrade")
def test_upgrade_with_existing_manager_tables_without_version_stamps_base_then_runs_migrations(
self,
mock_upgrade,
mock_stamp,
mock_inspect,
mock_get_script_object,
mock_get_alembic_config,
mock_create_db_from_orm,
mock_get_current_revision,
session,
):
config = object()
mock_get_alembic_config.return_value = config
base_revision = mock.Mock(revision="base-revision", down_revision=None)
mock_get_script_object.return_value.walk_revisions.return_value = [
mock.Mock(revision="head-revision", down_revision="base-revision"),
base_revision,
]
mock_inspect.return_value.get_table_names.return_value = ["external_legacy_table"]

manager = LegacyTablesDBManager(session)
manager.upgradedb()

mock_create_db_from_orm.assert_not_called()
mock_stamp.assert_called_once_with(config, "base-revision")
mock_upgrade.assert_called_once_with(config, revision="heads", sql=False)

@mock.patch.object(BaseDBManager, "get_script_object")
@mock.patch.object(BaseDBManager, "get_current_revision")
def test_check_migration(self, mock_script_obj, mock_current_revision, session):
Expand Down
Loading