Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 65 additions & 47 deletions airflow-core/src/airflow/utils/db_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any

from sqlalchemy import and_, column, func, inspect, select, table, text
from sqlalchemy import and_, column, func, inspect, literal_column, select, table, text
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import aliased
Expand All @@ -57,6 +57,7 @@
logger = logging.getLogger(__name__)

ARCHIVE_TABLE_PREFIX = "_airflow_deleted__"
_BASE_TABLE_ALIAS = "base"
# Archived tables created by DB migrations
ARCHIVED_TABLES_FROM_DB_MIGRATIONS = [
"_xcom_archive" # Table created by the AF 2 -> 3.0.0 migration when the XComs had pickled values
Expand Down Expand Up @@ -241,60 +242,78 @@ def _do_delete(
else:
print("Performing Delete...")

if skip_archive:
_delete_directly(query=limited_query, orm_model=orm_model, session=session)
continue

# using bulk delete
# create a new table and copy the rows there
timestamp_str = re.sub(r"[^\d]", "", timezone.utcnow().isoformat())[:14]
target_table_name = f"{ARCHIVE_TABLE_PREFIX}{orm_model.name}__{timestamp_str}{suffix}"
print(f"Moving data to table {target_table_name}")
target_table = None

try:
if dialect_name == "mysql":
# MySQL with replication needs this split into two queries, so just do it for all MySQL
# ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT.
session.execute(text(f"CREATE TABLE {target_table_name} LIKE {orm_model.name}"))
metadata = reflect_tables([target_table_name], session)
target_table = metadata.tables[target_table_name]
insert_stm = target_table.insert().from_select(target_table.c, limited_query)
logger.debug("insert statement:\n%s", insert_stm.compile())
session.execute(insert_stm)
else:
stmt = CreateTableAs(target_table_name, limited_query.selectable)
logger.debug("ctas query:\n%s", stmt.compile())
session.execute(stmt)
session.commit()

# delete the rows from the old table
metadata = reflect_tables([orm_model.name, target_table_name], session)
source_table = metadata.tables[orm_model.name]

if dialect_name == "mysql":
# MySQL with replication needs this split into two queries, so just do it for all MySQL
# ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT.
session.execute(text(f"CREATE TABLE {target_table_name} LIKE {orm_model.name}"))
metadata = reflect_tables([target_table_name], session)
target_table = metadata.tables[target_table_name]
logger.debug("rows moved; purging from %s", source_table.name)
if dialect_name == "sqlite":
pk_cols = source_table.primary_key.columns
delete = source_table.delete().where(
tuple_(*pk_cols).in_(
select(*[target_table.c[x.name] for x in source_table.primary_key.columns])
)
)
else:
delete = source_table.delete().where(
and_(*[col == target_table.c[col.name] for col in source_table.primary_key.columns])
insert_stm = target_table.insert().from_select(target_table.c, limited_query)
logger.debug("insert statement:\n%s", insert_stm.compile())
session.execute(insert_stm)
else:
stmt = CreateTableAs(target_table_name, limited_query.selectable)
logger.debug("ctas query:\n%s", stmt.compile())
session.execute(stmt)
session.commit()

# delete the rows from the old table
metadata = reflect_tables([orm_model.name, target_table_name], session)
source_table = metadata.tables[orm_model.name]
target_table = metadata.tables[target_table_name]
logger.debug("rows moved; purging from %s", source_table.name)
if dialect_name == "sqlite":
pk_cols = source_table.primary_key.columns
delete = source_table.delete().where(
tuple_(*pk_cols).in_(
select(*[target_table.c[x.name] for x in source_table.primary_key.columns])
)
logger.debug("delete statement:\n%s", delete.compile())
session.execute(delete)
session.commit()

except BaseException as e:
raise e
finally:
if target_table is not None and skip_archive:
bind = session.get_bind()
target_table.drop(bind=bind)
session.commit()
)
else:
delete = source_table.delete().where(
and_(*[col == target_table.c[col.name] for col in source_table.primary_key.columns])
)
logger.debug("delete statement:\n%s", delete.compile())
session.execute(delete)
session.commit()

print("Finished Performing Delete")


def _delete_directly(*, query: Select, orm_model: Base, session: Session) -> None:
    """
    Delete the rows matched by ``query`` straight from the source table, with no archive table.

    The source table is reflected by name, and ``query`` is narrowed to that table's
    primary-key columns. Those columns are addressed through the ``_BASE_TABLE_ALIAS``
    alias, so ``query`` must select from the table under that alias. The narrowed query
    becomes a subquery used as an ``IN`` filter for the ``DELETE``. The session is
    committed after the delete is executed.

    :param query: select statement identifying the rows to remove
    :param orm_model: object whose ``name`` is the table to delete from
    :param session: session used to reflect the table, run the delete and commit
    :raises ValueError: if the reflected table has no primary key columns
    """
    source_table = reflect_tables([orm_model.name], session).tables[orm_model.name]
    pk_names = [col.name for col in source_table.primary_key.columns]
    if not pk_names:
        raise ValueError(f"Table {orm_model.name} has no primary key columns available for cleanup.")

    # Re-project the incoming query onto just the primary key, referenced via the base alias.
    aliased_pk_exprs = [literal_column(f"{_BASE_TABLE_ALIAS}.{name}").label(name) for name in pk_names]
    rows_to_delete = query.with_only_columns(*aliased_pk_exprs).subquery("rows_to_delete")

    if len(pk_names) == 1:
        # Single-column key: a plain scalar IN is enough.
        (only_pk,) = pk_names
        delete_filter = source_table.c[only_pk].in_(select(rows_to_delete.c[only_pk]))
    else:
        # Composite key: compare the whole key as a tuple.
        delete_filter = tuple_(*[source_table.c[name] for name in pk_names]).in_(
            select(*[rows_to_delete.c[name] for name in pk_names])
        )

    delete = source_table.delete().where(delete_filter)
    logger.debug("direct delete statement:\n%s", delete.compile())
    session.execute(delete)
    session.commit()


def _subquery_keep_last(
*,
recency_column,
Expand Down Expand Up @@ -343,9 +362,8 @@ def _build_query(
exclude_dag_ids: list[str] | None = None,
**kwargs,
) -> Select:
base_table_alias = "base"
base_table = aliased(orm_model, name=base_table_alias)
query = select(text(f"{base_table_alias}.*")).select_from(base_table)
base_table = aliased(orm_model, name=_BASE_TABLE_ALIAS)
query = select(text(f"{_BASE_TABLE_ALIAS}.*")).select_from(base_table)
base_table_recency_col = base_table.c[recency_column.name]
conditions = [base_table_recency_col < clean_before_timestamp]

Expand Down
59 changes: 50 additions & 9 deletions airflow-core/tests/unit/utils/test_db_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from contextlib import suppress
from contextlib import contextmanager, suppress
from importlib import import_module
from io import StringIO
from pathlib import Path
Expand All @@ -26,7 +26,7 @@

import pendulum
import pytest
from sqlalchemy import func, inspect, select, text
from sqlalchemy import event, func, inspect, select, text
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.declarative import DeclarativeMeta

Expand Down Expand Up @@ -67,6 +67,21 @@
pytestmark = pytest.mark.db_test


@contextmanager
def capture_sql_statements(session):
    """
    Record the SQL statements executed through the session's bind while the context is open.

    A ``before_cursor_execute`` listener is attached to ``session.get_bind()`` on entry
    and removed again on exit, even if the body raises. The yielded list is filled with
    the statement string of each execution as it happens.
    """
    bound = session.get_bind()
    captured: list[str] = []

    def _record(_conn, _cursor, statement, _parameters, _context, _executemany):
        captured.append(statement)

    event.listen(bound, "before_cursor_execute", _record)
    try:
        yield captured
    finally:
        # Always detach so the listener does not leak into later tests.
        event.remove(bound, "before_cursor_execute", _record)


@pytest.fixture(autouse=True)
def clean_database():
"""Fixture that cleans the database before and after every test."""
Expand Down Expand Up @@ -447,7 +462,7 @@ def test_cleanup_with_dag_id_filtering(self, dag_ids, exclude_dag_ids, expected_
)
def test__skip_archive(self, skip_archive, expected_archives):
"""
Verify that running cleanup_table with drops the archives when requested.
Verify that running cleanup_table drops the archives when requested.

Archived tables from DB migration should be kept when skip_archive is True.
"""
Expand All @@ -474,13 +489,39 @@ def test__skip_archive(self, skip_archive, expected_archives):
assert session.scalar(select(func.count()).select_from(model)) == 5
assert len(_get_archived_table_names(["dag_run"], session)) == expected_archives

@patch("airflow.utils.db.reflect_tables")
def test_skip_archive_failure_will_remove_table(self, reflect_tables_mock):
"""
Verify that running cleanup_table with skip_archive = True, and failure happens.
@pytest.mark.parametrize(
"batch_size", [pytest.param(None, id="single_delete"), pytest.param(2, id="batched")]
)
def test_skip_archive_does_not_create_archive_table(self, batch_size):
"""Verify skip_archive avoids archive-table SQL instead of creating then dropping archives."""
base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("UTC"))
create_tis(base_date=base_date, num_tis=10)

The archive table should be removed from db if any exception.
"""
with create_session() as session:
for name in _get_archived_table_names(["dag_run"], session):
session.execute(text(f"DROP TABLE IF EXISTS {name}"))
session.commit()

clean_before_date = base_date.add(days=5)
with capture_sql_statements(session) as statements:
_cleanup_table(
**config_dict["dag_run"].__dict__,
clean_before_timestamp=clean_before_date,
dry_run=False,
session=session,
table_names=["dag_run"],
skip_archive=True,
batch_size=batch_size,
)

model = config_dict["dag_run"].orm_model
assert session.scalar(select(func.count()).select_from(model)) == 5
assert len(_get_archived_table_names(["dag_run"], session)) == 0
assert [statement for statement in statements if ARCHIVE_TABLE_PREFIX in statement] == []

@patch("airflow.utils.db_cleanup.reflect_tables")
def test_skip_archive_failure_does_not_create_archive_table(self, reflect_tables_mock):
"""Verify skip_archive failures do not leave archive tables behind."""
reflect_tables_mock.side_effect = SQLAlchemyError("Deletion failed")
base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("UTC"))
num_tis = 10
Expand Down
Loading