Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,23 @@
_MAX_SERIALIZED_BYTES = 65535


class AssetStoreLastUpdatedBy(BaseModel):
    """Writer info for the last write to an asset store entry."""

    # NOTE: the docstring above appears verbatim as this schema's ``description``
    # in the OpenAPI spec, so rewording it changes the published API document.

    # Category of the writer. The only required field. The routes in this change
    # pass ``AssetStoreWriterKind.API`` for core-API writes and
    # ``AssetStoreWriterKind.TASK`` for execution-API (task) writes.
    kind: str
    # Task-instance coordinates of the writer. All optional: the task write path
    # supplies them, while the core-API write path passes only ``kind``.
    dag_id: str | None = None
    run_id: str | None = None
    task_id: str | None = None
    map_index: int | None = None


class AssetStoreResponse(BaseModel):
    """A single asset store key/value pair with metadata."""

    # Key of the entry within the asset's store.
    key: str
    # Decoded JSON value; the route handlers build it with ``json.loads`` from the stored text.
    value: JsonValue
    # Timestamp taken from the row's ``updated_at`` column.
    updated_at: datetime
    # Who last wrote the entry. The route handlers leave this as ``None`` when the
    # row has no ``last_updated_by_kind`` recorded.
    last_updated_by: AssetStoreLastUpdatedBy | None = None


class AssetStoreCollectionResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11524,6 +11524,36 @@ components:
- total_entries
title: AssetStoreCollectionResponse
description: All asset store entries for an asset.
AssetStoreLastUpdatedBy:
properties:
kind:
type: string
title: Kind
dag_id:
anyOf:
- type: string
- type: 'null'
title: Dag Id
run_id:
anyOf:
- type: string
- type: 'null'
title: Run Id
task_id:
anyOf:
- type: string
- type: 'null'
title: Task Id
map_index:
anyOf:
- type: integer
- type: 'null'
title: Map Index
type: object
required:
- kind
title: AssetStoreLastUpdatedBy
description: Writer info for the last write to an asset store entry.
AssetStoreResponse:
properties:
key:
Expand All @@ -11535,6 +11565,10 @@ components:
type: string
format: date-time
title: Updated At
last_updated_by:
anyOf:
- $ref: '#/components/schemas/AssetStoreLastUpdatedBy'
- type: 'null'
type: object
required:
- key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
from fastapi import Depends, HTTPException, status
from sqlalchemy import select

from airflow._shared.state import AssetScope
from airflow._shared.state import AssetScope, AssetStoreWriterKind
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.asset_store import (
AssetStoreBody,
AssetStoreCollectionResponse,
AssetStoreLastUpdatedBy,
AssetStoreResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
Expand Down Expand Up @@ -73,6 +74,11 @@ def list_asset_store(
AssetStoreModel.key,
AssetStoreModel.value,
AssetStoreModel.updated_at,
AssetStoreModel.last_updated_by_kind,
AssetStoreModel.last_updated_by_dag_id,
AssetStoreModel.last_updated_by_run_id,
AssetStoreModel.last_updated_by_task_id,
AssetStoreModel.last_updated_by_map_index,
)
.where(AssetStoreModel.asset_id == asset_id)
.order_by(AssetStoreModel.key.asc())
Expand All @@ -87,7 +93,21 @@ def list_asset_store(
)
rows = session.execute(paginated).all()
entries = [
AssetStoreResponse(key=r.key, value=json.loads(r.value), updated_at=r.updated_at) for r in rows
AssetStoreResponse(
key=r.key,
value=json.loads(r.value),
updated_at=r.updated_at,
last_updated_by=AssetStoreLastUpdatedBy(
kind=r.last_updated_by_kind,
dag_id=r.last_updated_by_dag_id,
run_id=r.last_updated_by_run_id,
task_id=r.last_updated_by_task_id,
map_index=r.last_updated_by_map_index,
)
if r.last_updated_by_kind is not None
else None,
)
for r in rows
]
return AssetStoreCollectionResponse(asset_store=entries, total_entries=total_entries)

Expand All @@ -108,6 +128,11 @@ def get_asset_store(
AssetStoreModel.key,
AssetStoreModel.value,
AssetStoreModel.updated_at,
AssetStoreModel.last_updated_by_kind,
AssetStoreModel.last_updated_by_dag_id,
AssetStoreModel.last_updated_by_run_id,
AssetStoreModel.last_updated_by_task_id,
AssetStoreModel.last_updated_by_map_index,
).where(
AssetStoreModel.asset_id == asset_id,
AssetStoreModel.key == key,
Expand All @@ -118,7 +143,20 @@ def get_asset_store(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Asset store key {key!r} not found",
)
return AssetStoreResponse(key=row.key, value=json.loads(row.value), updated_at=row.updated_at)
return AssetStoreResponse(
key=row.key,
value=json.loads(row.value),
updated_at=row.updated_at,
last_updated_by=AssetStoreLastUpdatedBy(
kind=row.last_updated_by_kind,
dag_id=row.last_updated_by_dag_id,
run_id=row.last_updated_by_run_id,
task_id=row.last_updated_by_task_id,
map_index=row.last_updated_by_map_index,
)
if row.last_updated_by_kind is not None
else None,
)


@asset_store_router.put(
Expand All @@ -134,7 +172,13 @@ def set_asset_store(
session: SessionDep,
) -> None:
"""Set an asset store value. Creates or overwrites the key."""
_get_db_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session)
_get_db_backend().set_asset_store(
AssetScope(asset_id=asset_id),
key,
json.dumps(body.value),
kind=AssetStoreWriterKind.API,
session=session,
)


@asset_store_router.delete(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,39 @@
from fastapi import HTTPException, Query, status
from sqlalchemy import select

from airflow._shared.state import AssetScope
from airflow._shared.state import AssetScope, AssetStoreWriterKind
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.asset_store import (
AssetStorePutBody,
AssetStoreResponse,
)
from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute
from airflow.models.asset import AssetModel
from airflow.models.taskinstance import TaskInstance
from airflow.state import get_state_backend
from airflow.state.metastore import MetastoreStoreBackend

_TIWriterFields = tuple[str, str, str, int]


def _fetch_ti_writer_fields(token: TIToken, session: SessionDep) -> _TIWriterFields:
    """
    Look up the writer coordinates of the task instance named by ``token``.

    :param token: Task-instance token; only ``token.id`` is read.
    :param session: Database session used to run the lookup.
    :return: ``(dag_id, run_id, task_id, map_index)`` of the matching task instance.
    :raises HTTPException: 404 with a ``not_found`` reason when no task instance
        has ``id == token.id``.
    """
    # Select only the four identifying columns rather than loading the full ORM object.
    ti_query = select(
        TaskInstance.dag_id,
        TaskInstance.run_id,
        TaskInstance.task_id,
        TaskInstance.map_index,
    ).where(TaskInstance.id == token.id)
    ti_row = session.execute(ti_query).one_or_none()

    if ti_row is not None:
        return ti_row.dag_id, ti_row.run_id, ti_row.task_id, ti_row.map_index

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"reason": "not_found", "message": f"Task instance {token.id!r} not found"},
    )


# TODO(AIP-103): enforce that the requesting task is registered with the asset
# (via task_inlet_asset_reference or task_outlet_asset_reference) before
Expand Down Expand Up @@ -103,10 +127,27 @@ def set_asset_store_by_name(
key: Annotated[str, Query(min_length=1)],
body: AssetStorePutBody,
session: SessionDep,
token: TIToken = CurrentTIToken,
) -> None:
"""Set an asset store value by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session)
backend = get_state_backend()
scope = AssetScope(asset_id=asset_id)
if isinstance(backend, MetastoreStoreBackend):
dag_id, run_id, task_id, map_index = _fetch_ti_writer_fields(token, session)
backend.set_asset_store(
scope,
key,
json.dumps(body.value),
kind=AssetStoreWriterKind.TASK,
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
session=session,
)
else:
backend.set(scope, key, json.dumps(body.value), session=session)


@router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
Expand Down Expand Up @@ -153,10 +194,27 @@ def set_asset_store_by_uri(
key: Annotated[str, Query(min_length=1)],
body: AssetStorePutBody,
session: SessionDep,
token: TIToken = CurrentTIToken,
) -> None:
"""Set an asset store value by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session)
backend = get_state_backend()
scope = AssetScope(asset_id=asset_id)
if isinstance(backend, MetastoreStoreBackend):
dag_id, run_id, task_id, map_index = _fetch_ti_writer_fields(token, session)
backend.set_asset_store(
scope,
key,
json.dumps(body.value),
kind=AssetStoreWriterKind.TASK,
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
session=session,
)
else:
backend.set(scope, key, json.dumps(body.value), session=session)


@router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def upgrade():
sa.Column("key", sa.String(length=512), nullable=False),
sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False),
sa.Column("updated_at", UtcDateTime(), nullable=False),
sa.Column("last_updated_by_kind", sa.String(length=16), nullable=True),
sa.Column("last_updated_by_dag_id", StringID(), nullable=True),
sa.Column("last_updated_by_run_id", StringID(), nullable=True),
sa.Column("last_updated_by_task_id", StringID(), nullable=True),
sa.Column("last_updated_by_map_index", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["asset_id"], ["asset.id"], name="asset_store_asset_fkey", ondelete="CASCADE"
),
Expand Down
10 changes: 10 additions & 0 deletions airflow-core/src/airflow/models/asset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class AssetStoreModel(Base):

Not scoped to any DAG run — a watermark written in run 1 is readable by run 2.
Rows survive until explicitly deleted or the asset itself is deleted.

``last_updated_by_*`` columns record who last wrote this entry. They are denormalized
(no FK) so that the references survive DAG run cleanup, and so cases like watchers (``BaseEventTrigger``)
can write without a task instance.
"""

__tablename__ = "asset_store"
Expand All @@ -43,6 +47,12 @@ class AssetStoreModel(Base):
value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False)
updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)

last_updated_by_kind: Mapped[str | None] = mapped_column(String(16), nullable=True)
last_updated_by_dag_id: Mapped[str | None] = mapped_column(String(250, **COLLATION_ARGS), nullable=True)
last_updated_by_run_id: Mapped[str | None] = mapped_column(String(250, **COLLATION_ARGS), nullable=True)
last_updated_by_task_id: Mapped[str | None] = mapped_column(String(250, **COLLATION_ARGS), nullable=True)
last_updated_by_map_index: Mapped[int | None] = mapped_column(Integer, nullable=True)

__table_args__ = (
PrimaryKeyConstraint("asset_id", "key", name="asset_store_pkey"),
ForeignKeyConstraint(
Expand Down
Loading
Loading