diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py index 0c605b8d6bd7f..65b851f45cdfb 100644 --- a/airflow-core/src/airflow/utils/db_cleanup.py +++ b/airflow-core/src/airflow/utils/db_cleanup.py @@ -32,7 +32,7 @@ from types import SimpleNamespace from typing import TYPE_CHECKING, Any -from sqlalchemy import and_, column, func, inspect, select, table, text +from sqlalchemy import and_, column, func, inspect, literal_column, select, table, text from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import aliased @@ -57,6 +57,7 @@ logger = logging.getLogger(__name__) ARCHIVE_TABLE_PREFIX = "_airflow_deleted__" +_BASE_TABLE_ALIAS = "base" # Archived tables created by DB migrations ARCHIVED_TABLES_FROM_DB_MIGRATIONS = [ "_xcom_archive" # Table created by the AF 2 -> 3.0.0 migration when the XComs had pickled values @@ -241,60 +242,78 @@ def _do_delete( else: print("Performing Delete...") + if skip_archive: + _delete_directly(query=limited_query, orm_model=orm_model, session=session) + continue + # using bulk delete # create a new table and copy the rows there timestamp_str = re.sub(r"[^\d]", "", timezone.utcnow().isoformat())[:14] target_table_name = f"{ARCHIVE_TABLE_PREFIX}{orm_model.name}__{timestamp_str}{suffix}" print(f"Moving data to table {target_table_name}") - target_table = None - - try: - if dialect_name == "mysql": - # MySQL with replication needs this split into two queries, so just do it for all MySQL - # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT. - session.execute(text(f"CREATE TABLE {target_table_name} LIKE {orm_model.name}")) - metadata = reflect_tables([target_table_name], session) - target_table = metadata.tables[target_table_name] - insert_stm = target_table.insert().from_select(target_table.c, limited_query) - logger.debug("insert statement:\n%s", insert_stm.compile()) - session.execute(insert_stm) - else: - stmt = CreateTableAs(target_table_name, limited_query.selectable) - logger.debug("ctas query:\n%s", stmt.compile()) - session.execute(stmt) - session.commit() - - # delete the rows from the old table - metadata = reflect_tables([orm_model.name, target_table_name], session) - source_table = metadata.tables[orm_model.name] + + if dialect_name == "mysql": + # MySQL with replication needs this split into two queries, so just do it for all MySQL + # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT. + session.execute(text(f"CREATE TABLE {target_table_name} LIKE {orm_model.name}")) + metadata = reflect_tables([target_table_name], session) target_table = metadata.tables[target_table_name] - logger.debug("rows moved; purging from %s", source_table.name) - if dialect_name == "sqlite": - pk_cols = source_table.primary_key.columns - delete = source_table.delete().where( - tuple_(*pk_cols).in_( - select(*[target_table.c[x.name] for x in source_table.primary_key.columns]) - ) - ) - else: - delete = source_table.delete().where( - and_(*[col == target_table.c[col.name] for col in source_table.primary_key.columns]) + insert_stm = target_table.insert().from_select(target_table.c, limited_query) + logger.debug("insert statement:\n%s", insert_stm.compile()) + session.execute(insert_stm) + else: + stmt = CreateTableAs(target_table_name, limited_query.selectable) + logger.debug("ctas query:\n%s", stmt.compile()) + session.execute(stmt) + session.commit() + + # delete the rows from the old table + metadata = reflect_tables([orm_model.name, target_table_name], session) + source_table = metadata.tables[orm_model.name] + target_table = metadata.tables[target_table_name] + logger.debug("rows moved; purging from %s", source_table.name) + if dialect_name == "sqlite": + pk_cols = source_table.primary_key.columns + delete = source_table.delete().where( + tuple_(*pk_cols).in_( + select(*[target_table.c[x.name] for x in source_table.primary_key.columns]) ) - logger.debug("delete statement:\n%s", delete.compile()) - session.execute(delete) - session.commit() - - except BaseException as e: - raise e - finally: - if target_table is not None and skip_archive: - bind = session.get_bind() - target_table.drop(bind=bind) - session.commit() + ) + else: + delete = source_table.delete().where( + and_(*[col == target_table.c[col.name] for col in source_table.primary_key.columns]) + ) + logger.debug("delete statement:\n%s", delete.compile()) + session.execute(delete) + session.commit() print("Finished Performing Delete") +def _delete_directly(*, query: Select, orm_model: Base, session: Session) -> None: + metadata = reflect_tables([orm_model.name], session) + source_table = metadata.tables[orm_model.name] + pk_cols = list(source_table.primary_key.columns) + if not pk_cols: + raise ValueError(f"Table {orm_model.name} has no primary key columns available for cleanup.") + + pk_query = query.with_only_columns( + *[literal_column(f"{_BASE_TABLE_ALIAS}.{col.name}").label(col.name) for col in pk_cols] + ) + rows_to_delete = pk_query.subquery("rows_to_delete") + delete_filter = ( + source_table.c[pk_cols[0].name].in_(select(rows_to_delete.c[pk_cols[0].name])) + if len(pk_cols) == 1 + else tuple_(*[source_table.c[col.name] for col in pk_cols]).in_( + select(*[rows_to_delete.c[col.name] for col in pk_cols]) + ) + ) + delete = source_table.delete().where(delete_filter) + logger.debug("direct delete statement:\n%s", delete.compile()) + session.execute(delete) + session.commit() + + def _subquery_keep_last( *, recency_column, @@ -343,9 +362,8 @@ def _build_query( exclude_dag_ids: list[str] | None = None, **kwargs, ) -> Select: - base_table_alias = "base" - base_table = aliased(orm_model, name=base_table_alias) - query = select(text(f"{base_table_alias}.*")).select_from(base_table) + base_table = aliased(orm_model, name=_BASE_TABLE_ALIAS) + query = select(text(f"{_BASE_TABLE_ALIAS}.*")).select_from(base_table) base_table_recency_col = base_table.c[recency_column.name] conditions = [base_table_recency_col < clean_before_timestamp] diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py index b0d7bb50dc0e9..bd1aea84cfa78 100644 --- a/airflow-core/tests/unit/utils/test_db_cleanup.py +++ b/airflow-core/tests/unit/utils/test_db_cleanup.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from contextlib import suppress +from contextlib import contextmanager, suppress from importlib import import_module from io import StringIO from pathlib import Path @@ -26,7 +26,7 @@ import pendulum import pytest -from sqlalchemy import func, inspect, select, text +from sqlalchemy import event, func, inspect, select, text from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.ext.declarative import DeclarativeMeta @@ -67,6 +67,21 @@ pytestmark = pytest.mark.db_test +@contextmanager +def capture_sql_statements(session): + statements: list[str] = [] + bind = session.get_bind() + + def capture(_conn, _cursor, statement, _parameters, _context, _executemany): + statements.append(statement) + + event.listen(bind, "before_cursor_execute", capture) + try: + yield statements + finally: + event.remove(bind, "before_cursor_execute", capture) + + @pytest.fixture(autouse=True) def clean_database(): """Fixture that cleans the database before and after every test.""" @@ -447,7 +462,7 @@ def test_cleanup_with_dag_id_filtering(self, dag_ids, exclude_dag_ids, expected_ ) def test__skip_archive(self, skip_archive, expected_archives): """ - Verify that running cleanup_table with drops the archives when requested. + Verify that running cleanup_table drops the archives when requested. Archived tables from DB migration should be kept when skip_archive is True. """ @@ -474,13 +489,39 @@ def test__skip_archive(self, skip_archive, expected_archives): assert session.scalar(select(func.count()).select_from(model)) == 5 assert len(_get_archived_table_names(["dag_run"], session)) == expected_archives - @patch("airflow.utils.db.reflect_tables") - def test_skip_archive_failure_will_remove_table(self, reflect_tables_mock): - """ - Verify that running cleanup_table with skip_archive = True, and failure happens. + @pytest.mark.parametrize( + "batch_size", [pytest.param(None, id="single_delete"), pytest.param(2, id="batched")] + ) + def test_skip_archive_does_not_create_archive_table(self, batch_size): + """Verify skip_archive avoids archive-table SQL instead of creating then dropping archives.""" + base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("UTC")) + create_tis(base_date=base_date, num_tis=10) - The archive table should be removed from db if any exception. - """ + with create_session() as session: + for name in _get_archived_table_names(["dag_run"], session): + session.execute(text(f"DROP TABLE IF EXISTS {name}")) + session.commit() + + clean_before_date = base_date.add(days=5) + with capture_sql_statements(session) as statements: + _cleanup_table( + **config_dict["dag_run"].__dict__, + clean_before_timestamp=clean_before_date, + dry_run=False, + session=session, + table_names=["dag_run"], + skip_archive=True, + batch_size=batch_size, + ) + + model = config_dict["dag_run"].orm_model + assert session.scalar(select(func.count()).select_from(model)) == 5 + assert len(_get_archived_table_names(["dag_run"], session)) == 0 + assert [statement for statement in statements if ARCHIVE_TABLE_PREFIX in statement] == [] + + @patch("airflow.utils.db_cleanup.reflect_tables") + def test_skip_archive_failure_does_not_create_archive_table(self, reflect_tables_mock): + """Verify skip_archive failures do not leave archive tables behind.""" reflect_tables_mock.side_effect = SQLAlchemyError("Deletion failed") base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("UTC")) num_tis = 10