19 changes: 19 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
@@ -1519,6 +1519,20 @@ class GroupCommand(NamedTuple):
args=(ARG_OUTPUT, ARG_VERBOSE),
),
)
STATE_STORE_COMMANDS = (
ActionCommand(
name="cleanup-task-states",
help="Remove expired task state rows (MetastoreStateBackend only)",
description=(
"Reads [state_store] default_retention_days from config and deletes task_state rows "
"older than the configured threshold. Only applies when MetastoreStateBackend is configured; "
"custom backends are skipped. Use --dry-run to preview without deleting."
),
func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup_task_states"),
args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
),
)

DB_COMMANDS = (
ActionCommand(
name="check-migrations",
@@ -2102,6 +2116,11 @@ class GroupCommand(NamedTuple):
help="Display providers",
subcommands=PROVIDERS_COMMANDS,
),
GroupCommand(
name="state-store",
help="Manage task and asset state storage",
subcommands=STATE_STORE_COMMANDS,
),
ActionCommand(
name="rotate-fernet-key",
func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
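Once registered, the command runs as `airflow state-store cleanup-task-states`; per the description above, `--dry-run` previews the rows that would be deleted without removing anything.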
49 changes: 49 additions & 0 deletions airflow-core/src/airflow/cli/commands/state_store_command.py
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging

from airflow.state import get_state_backend
from airflow.state.metastore import MetastoreStateBackend

log = logging.getLogger(__name__)

# Other state operations (list, get, delete per key) will be added here in the future.


def cleanup_task_states(args) -> None:
"""Remove expired task state rows (MetastoreStateBackend only)."""
backend = get_state_backend()

if not isinstance(backend, MetastoreStateBackend):
print("Custom backend configured — skipping cleanup (not supported).")
return

if args.dry_run:
summary = backend._summary_dry_run_()
expired = summary["expired"]
if not expired:
print("Nothing to delete.")
return
print(f"Would delete {len(expired)} task state row(s):\n")
for dag_id, run_id, task_id, map_index, key in expired:
print(f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}")
return

log.info("Running task state cleanup")
backend.cleanup()
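For context, `cleanup_task_states` branches on `get_state_backend()`, whose body sits outside this diff. A minimal sketch of how such a resolver typically looks, assuming the class path lives in a `[state_store]` option — the option name `backend` is a guess; only its example and default values are visible in the config.yml hunk below:

```python
# Hypothetical sketch, not the PR's implementation.
from airflow.configuration import conf


def get_state_backend():
    # Assumed option name; config.yml defaults it to
    # "airflow.state.metastore.MetastoreStateBackend".
    backend_cls = conf.getimport("state_store", "backend")
    return backend_cls()
```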
18 changes: 18 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
@@ -3025,6 +3025,24 @@ state_store:
type: string
example: "mypackage.state.CustomStateBackend"
default: "airflow.state.metastore.MetastoreStateBackend"
default_retention_days:
description: |
Number of days to retain task state rows after their last update.
Rows older than this are removed by the scheduler's periodic cleanup.
This config does not affect asset_state rows.
Set to 0 to disable time-based cleanup entirely.
version_added: 3.3.0
type: integer
example: "7"
default: "30"
state_cleanup_batch_size:
description: |
Number of rows deleted per batch during cleanup. Defaults to 0 (no batching).
Set a positive value on deployments with large task_state tables to keep each
delete transaction bounded.
version_added: 3.3.0
type: integer
example: "10000"
default: "0"

profiling:
description: |
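For reference, a deployment opting into shorter retention and batched deletes would set both options in airflow.cfg; the values below are illustrative, not recommendations:

```ini
[state_store]
default_retention_days = 7        # task_state rows expire a week after their last write
state_cleanup_batch_size = 10000  # delete expired rows in bounded chunks
```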
32 changes: 31 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,7 +33,20 @@
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy import (
CTE,
and_,
case,
delete,
exists,
func,
inspect,
or_,
select,
text,
tuple_,
update,
)
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
@@ -70,6 +83,7 @@
TaskInletAssetReference,
TaskOutletAssetReference,
)
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
@@ -3092,6 +3106,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:

self._orphan_unreferenced_assets(orphan_query, session=session)
self._activate_referenced_assets(activate_query, session=session)
self._cleanup_orphaned_asset_state(session=session)

@staticmethod
def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None:
@@ -3200,6 +3215,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)

@staticmethod
def _cleanup_orphaned_asset_state(*, session: Session) -> None:
"""
Delete asset_state rows for assets no longer active in any Dag.

When _orphan_unreferenced_assets removes an asset from asset_active, its
asset_state rows become unreachable — no task can write to them anymore.
This runs in the same pass as asset orphanage to keep the table clean.
"""
active_asset_ids = select(AssetModel.id).join(
AssetActive,
(AssetActive.name == AssetModel.name) & (AssetActive.uri == AssetModel.uri),
)
session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))

def _executor_to_workloads(
self,
workloads: Iterable[SchedulerWorkload],
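A small note on the anti-join above: SQL's NOT IN has a well-known NULL pitfall, but AssetModel.id is a non-nullable primary key, so the subquery can never yield NULL. And when the subquery is empty (no active assets at all), NOT IN is true for every row, so every orphaned asset_state row is swept — the intended behavior in both cases.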
@@ -57,6 +57,7 @@ def upgrade():
)
op.create_table(
"task_state",
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
sa.Column("dag_run_id", sa.Integer(), nullable=False),
sa.Column("task_id", StringID(), nullable=False),
sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False),
@@ -65,20 +66,24 @@
sa.Column("run_id", StringID(), nullable=False),
sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False),
sa.Column("updated_at", UtcDateTime(), nullable=False),
sa.Column("expires_at", UtcDateTime(), nullable=True),
sa.ForeignKeyConstraint(
["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
sa.PrimaryKeyConstraint("id", name="task_state_pkey"),
sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
)
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.create_index(
"idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False
)
batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False)


def downgrade():
"""Unapply add task_state and asset_state tables."""
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.drop_index("idx_task_state_expires_at")
batch_op.drop_index("idx_task_state_lookup")

op.drop_table("task_state")
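The migration's switch from the composite primary key to a surrogate id (with the old key demoted to the task_state_uq unique constraint) is presumably what makes the batched cleanup in MetastoreStateBackend practical: each batch selects a LIMITed list of single-column ids and deletes by IN-list, which is simpler and more portable across dialects than tuple-valued deletes against a four-column key.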
21 changes: 15 additions & 6 deletions airflow-core/src/airflow/models/task_state.py
@@ -19,7 +19,7 @@

from datetime import datetime

from sqlalchemy import ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, Text
from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Mapped, mapped_column

@@ -39,24 +39,33 @@ class TaskStateModel(Base):

__tablename__ = "task_state"

dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, primary_key=True)
task_id: Mapped[str] = mapped_column(StringID(), nullable=False, primary_key=True)
map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)

dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False)
task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False)

dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
run_id: Mapped[str] = mapped_column(StringID(), nullable=False)

value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False)
updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
# Expiry timestamp, computed at write time from [state_store] default_retention_days
# or a per-key retention_days override passed via task_state.set(retention_days=N).
# Cleanup deletes the row once expires_at < now(). NULL (retention of 0) means the
# row never expires by time; it is removed only via the dag_run FK cascade or an
# explicit clear.
expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True)

__table_args__ = (
PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
ForeignKeyConstraint(
["dag_run_id"],
["dag_run.id"],
name="task_state_dag_run_fkey",
ondelete="CASCADE",
),
Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"),
Index("idx_task_state_expires_at", "expires_at"),
)
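The expires_at comment implies a per-key override on the task-facing API. A hedged usage sketch follows; `context["task_state"]` appears in the review thread further down, but the exact `set()` signature is inferred, not shown in this diff:

```python
# Hypothetical usage — the retention_days parameter is inferred from the model
# comment (task_state.set(retention_days=N)); signature not verified against the PR.
def my_task(**context):
    state = context["task_state"]
    # Expires 7 days after this write, overriding default_retention_days.
    state.set("checkpoint", "batch-42", retention_days=7)
    # retention_days=0 writes expires_at=NULL: kept until the DagRun is
    # deleted (FK cascade) or the key is cleared explicitly.
    state.set("watermark", "2026-05-13", retention_days=0)
```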
73 changes: 70 additions & 3 deletions airflow-core/src/airflow/state/metastore.py
@@ -17,17 +17,20 @@
# under the License.
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import delete, select

from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope
from airflow._shared.timezones import timezone
from airflow.configuration import conf
from airflow.models.asset_state import AssetStateModel
from airflow.models.dagrun import DagRun
from airflow.models.task_state import TaskStateModel
from airflow.typing_compat import assert_never
from airflow.utils.session import NEW_SESSION, create_session_async, provide_session
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
from airflow.utils.sqlalchemy import get_dialect_name

if TYPE_CHECKING:
@@ -38,6 +41,21 @@
from sqlalchemy.orm import Session


log = structlog.get_logger(__name__)


def _compute_expires_at(now: datetime) -> datetime | None:
"""
Return the expiry timestamp for a new task state row based on config.

Returns None if default_retention_days is 0 or negative (rows never expire).
"""
retention_days = conf.getint("state_store", "default_retention_days")
if retention_days <= 0:
return None
return now + timedelta(days=retention_days)


def _build_upsert_stmt(
dialect: str | None,
model: type,
@@ -176,6 +194,7 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}")
now = timezone.utcnow()
expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
@@ -185,13 +204,14 @@
key=key,
value=value,
updated_at=now,
expires_at=expires_at,
)
stmt = _build_upsert_stmt(
get_dialect_name(session),
TaskStateModel,
["dag_run_id", "task_id", "map_index", "key"],
values,
dict(value=value, updated_at=now),
dict(value=value, updated_at=now, expires_at=expires_at),
)
session.execute(stmt)

@@ -252,6 +272,51 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None:
)
)

def cleanup(self) -> None:
"""
Remove expired task state rows.

``expires_at`` is set at write time on every ``set()`` call, so cleanup is a single
``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL`` (default_retention_days=0)
are never deleted. Batching is configurable via ``[state_store] state_cleanup_batch_size``.
"""
batch_size = conf.getint("state_store", "state_cleanup_batch_size")
now = timezone.utcnow()

def _delete_batched(where_clause) -> int:
total = 0
with create_session() as session:
while True:
id_query = select(TaskStateModel.id).where(where_clause)
if batch_size > 0:
id_query = id_query.limit(batch_size)
ids = session.scalars(id_query).all()
if not ids:
break
session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids)))
session.commit()
total += len(ids)
if batch_size <= 0 or len(ids) < batch_size:
break
return total

deleted = _delete_batched(TaskStateModel.expires_at < now)
[Review thread]

Member (uranusjr): I wonder if it'd be a good idea if the actual expiration is calculated on the fly instead. If I'm understanding correctly, this currently relies on the expires_at column being correctly updated whenever updated_at is updated (if the former is not set explicitly). This seems a bit fragile.

Contributor Author (amoghrajesh): expires_at is set once at write time in every set() call and is never updated independently of the row — there's no dependency on updated_at being in sync with it. If you call set() again on the same key, the upsert recalculates and overwrites both updated_at and expires_at together atomically.

One legitimate edge case you may be pointing at: if a user starts with default_retention_days=0, then later raises it to 30 days, those old NULL rows won't be picked up by the current `WHERE expires_at < now()` pass. We can add a second pass `WHERE expires_at IS NULL AND updated_at < now - default_retention_days` for that case. How does that sound?

Member (uranusjr): What do you mean by a second pass? Where would this happen? (In abstract it sounds like a plan; it's similar to how the next run needs to be recalculated when you change the dag schedule definition.)

Contributor Author (amoghrajesh): A second pass would be something like `WHERE expires_at IS NULL AND updated_at < now - default_retention_days`, to catch rows that were written when default_retention_days=0 but the config was later raised.

Something like:

```python
# Pass 1: code right now
deleted_expired = _delete_batched(TaskStateModel.expires_at < now)

# Pass 2: rows with NULL expires_at that are stale under the current global default
if default_retention_days > 0:
    cutoff = now - timedelta(days=default_retention_days)
    deleted_stale = _delete_batched(
        TaskStateModel.expires_at.is_(None) & (TaskStateModel.updated_at < cutoff)
    )
```

It would run in the same airflow state-store cleanup command.

But on thinking more, I do not think that it is needed. expires_at=NULL is an explicit signal — either default_retention_days=0 was set, or retention_days=0 was passed at write time. Both mean "keep this row forever." Retroactively deleting them on a config change would violate what was promised at write time.

Member (uranusjr): I wonder if we should add a command to change retention of existing TIs; I feel some would need it, either because they didn't know about the feature previously or want to change policy entirely. Or maybe those people can just manually delete the TIs in another way anyway?

Contributor Author (amoghrajesh): TBH, I would hold off for now — people can run direct SQL statements if needed, and any new key(s) that get written will automatically pick up the new default_retention_days configured. If there is clear demand for it we can add something like airflow state-store set-expiry when needed, but that feels premature before we know how common the use case is.

Contributor Author (amoghrajesh, May 13, 2026): @uranusjr now that I think of it, they can also delete it by calling `context["task_state"].clear(all_map_indices=True)`, so it's possible in case they want it.

log.info("Deleted expired task_state rows", rows_deleted=deleted)

def _summary_dry_run_(self) -> dict[str, list]:
[Review thread]

Member: Suggested change:

-    def _summary_dry_run_(self) -> dict[str, list]:
+    def _summary_dry_run(self) -> dict[str, list]:

typo?


"""Return rows that would be deleted by cleanup() without deleting anything."""
now = timezone.utcnow()
cols = (
TaskStateModel.dag_id,
TaskStateModel.run_id,
TaskStateModel.task_id,
TaskStateModel.map_index,
TaskStateModel.key,
)
with create_session() as session:
expired = session.execute(select(*cols).where(TaskStateModel.expires_at < now)).all()
return {"expired": list(expired)}

async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None:
row = await session.scalar(
select(TaskStateModel).where(
Expand All @@ -276,6 +341,7 @@ async def _aset_task_state(
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}")
now = timezone.utcnow()
expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
Expand All @@ -285,14 +351,15 @@ async def _aset_task_state(
key=key,
value=value,
updated_at=now,
expires_at=expires_at,
)
# get_dialect_name expects a sync Session; sync_session is the underlying Session the async wrapper delegates to
stmt = _build_upsert_stmt(
get_dialect_name(session.sync_session),
TaskStateModel,
["dag_run_id", "task_id", "map_index", "key"],
values,
dict(value=value, updated_at=now),
dict(value=value, updated_at=now, expires_at=expires_at),
)
await session.execute(stmt)

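One helper worth sketching: `_build_upsert_stmt` is called with new arguments in the hunks above, but its body is outside the visible diff. A plausible PostgreSQL-only shape, under the assumption that the real helper branches per dialect (MySQL's on_duplicate_key_update, SQLite's own on_conflict_do_update):

```python
# Hypothetical sketch of the dialect-aware upsert builder; the PR's actual
# implementation is not shown in this diff.
from sqlalchemy.dialects.postgresql import insert as pg_insert


def _build_upsert_stmt(dialect, model, conflict_cols, values, update_values):
    if dialect == "postgresql":
        stmt = pg_insert(model).values(**values)
        # conflict_cols matches the task_state_uq unique constraint
        # (dag_run_id, task_id, map_index, key).
        return stmt.on_conflict_do_update(
            index_elements=conflict_cols, set_=update_values
        )
    raise NotImplementedError(f"sketch covers postgresql only, got {dialect!r}")
```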