-
Notifications
You must be signed in to change notification settings - Fork 17k
AIP-103: Adding periodic task state garbage collection and retention support #66463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b644ce6
cdc4237
df379c5
7f401d4
f52ce27
7427d04
151dee5
66081d0
58dba88
6b6968e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
|
|
||
| from airflow.state import get_state_backend | ||
| from airflow.state.metastore import MetastoreStateBackend | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
| # Other state operations (list, get, delete per key) will be added here in the future. | ||
|
|
||
|
|
||
def cleanup_task_states(args) -> None:
    """Remove expired task state rows (MetastoreStateBackend only)."""
    backend = get_state_backend()

    # Cleanup relies on metastore internals; bail out for custom backends.
    if not isinstance(backend, MetastoreStateBackend):
        print("Custom backend configured — skipping cleanup (not supported).")
        return

    if not args.dry_run:
        log.info("Running task state cleanup")
        backend.cleanup()
        return

    # Dry run: report what would be removed without touching the table.
    expired = backend._summary_dry_run_()["expired"]
    if not expired:
        print("Nothing to delete.")
        return
    print(f"Would delete {len(expired)} task state row(s):\n")
    for dag_id, run_id, task_id, map_index, key in expired:
        print(f"  Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}")
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -17,17 +17,20 @@ | |||||
| # under the License. | ||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from datetime import datetime, timedelta | ||||||
| from typing import TYPE_CHECKING | ||||||
|
|
||||||
| import structlog | ||||||
| from sqlalchemy import delete, select | ||||||
|
|
||||||
| from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope | ||||||
| from airflow._shared.timezones import timezone | ||||||
| from airflow.configuration import conf | ||||||
| from airflow.models.asset_state import AssetStateModel | ||||||
| from airflow.models.dagrun import DagRun | ||||||
| from airflow.models.task_state import TaskStateModel | ||||||
| from airflow.typing_compat import assert_never | ||||||
| from airflow.utils.session import NEW_SESSION, create_session_async, provide_session | ||||||
| from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session | ||||||
| from airflow.utils.sqlalchemy import get_dialect_name | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
|
|
@@ -38,6 +41,21 @@ | |||||
| from sqlalchemy.orm import Session | ||||||
|
|
||||||
|
|
||||||
| log = structlog.get_logger(__name__) | ||||||
|
|
||||||
|
|
||||||
def _compute_expires_at(now: datetime) -> datetime | None:
    """
    Return the expiry timestamp for a new task state row based on config.

    Returns None if default_retention_days is 0 (never expires).
    """
    days = conf.getint("state_store", "default_retention_days")
    # Non-positive retention means "keep forever" — no expiry is stamped.
    return now + timedelta(days=days) if days > 0 else None
|
|
||||||
|
|
||||||
| def _build_upsert_stmt( | ||||||
| dialect: str | None, | ||||||
| model: type, | ||||||
|
|
@@ -176,6 +194,7 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se | |||||
| if dag_run_id is None: | ||||||
| raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") | ||||||
| now = timezone.utcnow() | ||||||
| expires_at = _compute_expires_at(now) | ||||||
| values = dict( | ||||||
| dag_run_id=dag_run_id, | ||||||
| dag_id=scope.dag_id, | ||||||
|
|
@@ -185,13 +204,14 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se | |||||
| key=key, | ||||||
| value=value, | ||||||
| updated_at=now, | ||||||
| expires_at=expires_at, | ||||||
| ) | ||||||
| stmt = _build_upsert_stmt( | ||||||
| get_dialect_name(session), | ||||||
| TaskStateModel, | ||||||
| ["dag_run_id", "task_id", "map_index", "key"], | ||||||
| values, | ||||||
| dict(value=value, updated_at=now), | ||||||
| dict(value=value, updated_at=now, expires_at=expires_at), | ||||||
| ) | ||||||
| session.execute(stmt) | ||||||
|
|
||||||
|
|
@@ -252,6 +272,51 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None: | |||||
| ) | ||||||
| ) | ||||||
|
|
||||||
def cleanup(self) -> None:
    """
    Remove expired task state rows.

    ``expires_at`` is set at write time on every ``set()`` call, so cleanup is a single
    ``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL`` (default_retention_days=0)
    are never deleted. Batching is configurable via ``[state_store] state_cleanup_batch_size``.
    """
    limit = conf.getint("state_store", "state_cleanup_batch_size")
    cutoff = timezone.utcnow()

    def _purge(predicate) -> int:
        # Delete matching rows ``limit`` at a time (single unbounded pass when
        # limit <= 0), committing after each batch to keep transactions short.
        removed = 0
        with create_session() as session:
            while True:
                stmt = select(TaskStateModel.id).where(predicate)
                if limit > 0:
                    stmt = stmt.limit(limit)
                batch = session.scalars(stmt).all()
                if not batch:
                    break
                session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(batch)))
                session.commit()
                removed += len(batch)
                # A short batch means the table is drained; stop looping.
                if limit <= 0 or len(batch) < limit:
                    break
        return removed

    deleted = _purge(TaskStateModel.expires_at < cutoff)
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if it’d be a good idea if the actual expiration is calculated on the fly instead. If I’m understanding correctly, this currently relies on the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
One legitimate edge case you may be pointing at: if a user starts with
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you mean by a second pass? Where would this happen? (In abstract it sounds like a plan; it’s similar to how the next run needs to be recalculated when you change the dag schedule definition.)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A second pass would be something like Something like: # Pass 1: code right now
deleted_expired = _delete_batched(TaskStateModel.expires_at < now)
# Pass 2: rows with NULL expires_at that are stale under the current global default
if default_retention_days > 0:
cutoff = now - timedelta(days=default_retention_days)
deleted_stale = _delete_batched(
TaskStateModel.expires_at.is_(None) & (TaskStateModel.updated_at < cutoff)
)It would run in the same But on thinking more, I do not think that it is needed.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we should add a command to change retention of existing tis; I feel some would need it, either because they didn’t know about the feature previously or want to change policy entirely. Or maybe those people can just manually delete the tis in another way anyway?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TBH, i would hold off for now — people can run a direct sql statements if needed, and any new key(s) that gets written will automatically pick up the new
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @uranusjr now that I think of it, they can also delete it by calling a |
||||||
| log.info("Deleted expired task_state rows", rows_deleted=deleted) | ||||||
|
|
||||||
def _summary_dry_run_(self) -> dict[str, list]:
    """Return rows that would be deleted by cleanup() without deleting anything."""
    cutoff = timezone.utcnow()
    # Mirror the predicate used by cleanup(): only rows whose expiry has passed.
    stmt = select(
        TaskStateModel.dag_id,
        TaskStateModel.run_id,
        TaskStateModel.task_id,
        TaskStateModel.map_index,
        TaskStateModel.key,
    ).where(TaskStateModel.expires_at < cutoff)
    with create_session() as session:
        rows = session.execute(stmt).all()
    return {"expired": list(rows)}
|
|
||||||
| async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None: | ||||||
| row = await session.scalar( | ||||||
| select(TaskStateModel).where( | ||||||
|
|
@@ -276,6 +341,7 @@ async def _aset_task_state( | |||||
| if dag_run_id is None: | ||||||
| raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") | ||||||
| now = timezone.utcnow() | ||||||
| expires_at = _compute_expires_at(now) | ||||||
| values = dict( | ||||||
| dag_run_id=dag_run_id, | ||||||
| dag_id=scope.dag_id, | ||||||
|
|
@@ -285,14 +351,15 @@ async def _aset_task_state( | |||||
| key=key, | ||||||
| value=value, | ||||||
| updated_at=now, | ||||||
| expires_at=expires_at, | ||||||
| ) | ||||||
| # get_dialect_name expects a sync Session; sync_session is the underlying Session the async wrapper delegates to | ||||||
| stmt = _build_upsert_stmt( | ||||||
| get_dialect_name(session.sync_session), | ||||||
| TaskStateModel, | ||||||
| ["dag_run_id", "task_id", "map_index", "key"], | ||||||
| values, | ||||||
| dict(value=value, updated_at=now), | ||||||
| dict(value=value, updated_at=now, expires_at=expires_at), | ||||||
| ) | ||||||
| await session.execute(stmt) | ||||||
|
|
||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.