Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions airflow-core/src/airflow/utils/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import os
from typing import TYPE_CHECKING

from alembic import command
from sqlalchemy import inspect

from airflow import settings
Expand Down Expand Up @@ -50,6 +49,12 @@ def _callable_accepts_use_migration_files(callable_) -> bool:
)


def _get_alembic_command():
from alembic import command

return command


class BaseDBManager(LoggingMixin):
"""Abstract Base DB manager for external DBs."""

Expand Down Expand Up @@ -126,7 +131,7 @@ def create_db_from_orm(self):
engine = self.session.get_bind().engine
self.metadata.create_all(engine)
config = self.get_alembic_config()
command.stamp(config, "head")
_get_alembic_command().stamp(config, "head")
self.log.info("%s tables have been created from the ORM", self.__class__.__name__)

def drop_tables(self, connection):
Expand Down Expand Up @@ -180,7 +185,7 @@ def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False, u
return

config = self.get_alembic_config()
command.upgrade(config, revision=to_revision or "heads", sql=show_sql_only)
_get_alembic_command().upgrade(config, revision=to_revision or "heads", sql=show_sql_only)
self.log.info("Migrated the %s database", self.__class__.__name__)

def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
Expand Down
19 changes: 19 additions & 0 deletions airflow-core/tests/unit/utils/test_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import importlib
import sys
from contextlib import nullcontext
from unittest import mock

Expand Down Expand Up @@ -97,6 +99,23 @@ def _create_run_db_manager(*managers):


class TestBaseDBManager:
def test_importing_db_manager_does_not_eagerly_import_alembic(self):
original_db_manager_module = sys.modules.get("airflow.utils.db_manager")

try:
sys.modules.pop("airflow.utils.db_manager", None)
for module_name in list(sys.modules):
if module_name == "alembic" or module_name.startswith("alembic."):
sys.modules.pop(module_name)

module = importlib.import_module("airflow.utils.db_manager")

assert hasattr(module, "RunDBManager")
assert not any(name == "alembic" or name.startswith("alembic.") for name in sys.modules)
finally:
if original_db_manager_module is not None:
sys.modules["airflow.utils.db_manager"] = original_db_manager_module

@mock.patch.object(BaseDBManager, "get_alembic_config")
@mock.patch.object(BaseDBManager, "get_current_revision")
@mock.patch.object(BaseDBManager, "create_db_from_orm")
Expand Down
Loading