Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
# under the License.
from __future__ import annotations

import json
from datetime import datetime

from pydantic import Field
from pydantic import JsonValue, field_validator

from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel

_MAX_SERIALIZED_BYTES = 65535


class AssetStateResponse(BaseModel):
"""A single asset state key/value pair with metadata."""

key: str
value: str
value: JsonValue
updated_at: datetime


Expand All @@ -41,4 +44,17 @@ class AssetStateCollectionResponse(BaseModel):
class AssetStateBody(StrictBaseModel):
"""Request body for setting an asset state value."""

value: str = Field(max_length=65535)
value: JsonValue

@field_validator("value")
@classmethod
def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
if v is None:
raise ValueError("value cannot be null")
try:
serialized = json.dumps(v, allow_nan=False)
except ValueError:
raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable")
if len(serialized) > _MAX_SERIALIZED_BYTES:
raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes")
return v
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
# under the License.
from __future__ import annotations

import json
from datetime import datetime

from pydantic import Field
from pydantic import JsonValue, field_validator

from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel

_MAX_SERIALIZED_BYTES = 65535


class TaskStateResponse(BaseModel):
"""A single task state key/value pair with metadata."""

key: str
value: str
value: JsonValue
updated_at: datetime
expires_at: datetime | None

Expand All @@ -42,4 +45,17 @@ class TaskStateCollectionResponse(BaseModel):
class TaskStateBody(StrictBaseModel):
"""Request body for setting a task state value."""

value: str = Field(max_length=65535)
value: JsonValue

@field_validator("value")
@classmethod
def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
if v is None:
raise ValueError("value cannot be null")
try:
serialized = json.dumps(v, allow_nan=False)
except ValueError:
raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable")
if len(serialized) > _MAX_SERIALIZED_BYTES:
raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes")
return v
Original file line number Diff line number Diff line change
Expand Up @@ -11360,9 +11360,7 @@ components:
AssetStateBody:
properties:
value:
type: string
maxLength: 65535
title: Value
$ref: '#/components/schemas/JsonValue'
additionalProperties: false
type: object
required:
Expand Down Expand Up @@ -11391,8 +11389,7 @@ components:
type: string
title: Key
value:
type: string
title: Value
$ref: '#/components/schemas/JsonValue'
updated_at:
type: string
format: date-time
Expand Down Expand Up @@ -15810,9 +15807,7 @@ components:
TaskStateBody:
properties:
value:
type: string
maxLength: 65535
title: Value
$ref: '#/components/schemas/JsonValue'
additionalProperties: false
type: object
required:
Expand Down Expand Up @@ -15841,8 +15836,7 @@ components:
type: string
title: Key
value:
type: string
title: Value
$ref: '#/components/schemas/JsonValue'
updated_at:
type: string
format: date-time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from typing import Annotated

from fastapi import Depends, HTTPException, status
Expand Down Expand Up @@ -85,7 +86,9 @@ def list_asset_states(
session=session,
)
rows = session.execute(paginated).all()
entries = [AssetStateResponse(key=r.key, value=r.value, updated_at=r.updated_at) for r in rows]
entries = [
AssetStateResponse(key=r.key, value=json.loads(r.value), updated_at=r.updated_at) for r in rows
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same json.loads 500 risk noted on task_state.py:92 -- one bad row in the asset's state table will 500 the whole list endpoint. Worth defensive decoding here, especially since asset state is meant to be a long-lived watermark (so legacy rows from pre-#67418 are more likely to still be around than transient task state).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as task state — both write paths always store json.dumps(...), so no non-JSON rows can exist. The long-lived watermark concern doesn't apply since 3.3 hasn't shipped yet.

]
return AssetStateCollectionResponse(asset_states=entries, total_entries=total_entries)


Expand Down Expand Up @@ -115,7 +118,7 @@ def get_asset_state(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Asset state key {key!r} not found",
)
return AssetStateResponse(key=row.key, value=row.value, updated_at=row.updated_at)
return AssetStateResponse(key=row.key, value=json.loads(row.value), updated_at=row.updated_at)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same json.loads 500 risk as the list endpoint (line 90). A legacy or custom-backend row turns GET into a 500.

Copy link
Copy Markdown
Contributor Author

@amoghrajesh amoghrajesh May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above — no non-JSON rows can exist since both write paths always do json.dumps(...) before storing.



@asset_state_router.put(
Expand All @@ -131,7 +134,7 @@ def set_asset_state(
session: SessionDep,
) -> None:
"""Set an asset state value. Creates or overwrites the key."""
get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session)


@asset_state_router.delete(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from typing import Annotated

from fastapi import Depends, HTTPException, Query, status
Expand Down Expand Up @@ -87,7 +88,9 @@ def list_task_states(
)
rows = session.execute(paginated).all()
entries = [
TaskStateResponse(key=r.key, value=r.value, updated_at=r.updated_at, expires_at=r.expires_at)
TaskStateResponse(
key=r.key, value=json.loads(r.value), updated_at=r.updated_at, expires_at=r.expires_at
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

json.loads(r.value) will raise JSONDecodeError on any row whose value column is not valid JSON, and FastAPI surfaces that as a generic 500. Two ways this hits prod:

  1. Rows written by the execution API between AIP-103: Add Execution API endpoints for task and asset states #66073 (2026-05-04) and Simplifing authoring of task and asset states by allowing JSON types #67418 (2026-05-25) stored raw strings; Simplifing authoring of task and asset states by allowing JSON types #67418 switched to json.dumps(...) but didn't migrate existing rows. Anyone who ran on a pre-Simplifing authoring of task and asset states by allowing JSON types #67418 build now poisons reads.
  2. The BaseStateBackend interface only requires set(scope, key, value: str, ...). A custom backend that stores its own value format (e.g., escaped/quoted, msgpack-decoded-to-str, etc.) would have worked under the old value: str contract and now breaks.

In a list endpoint this is worse than in get: one bad row poisons the whole page, so users can't even paginate past it.

Suggest wrapping the decode in a try/except (skip+log, or return a sentinel) so a single legacy/odd row doesn't 500 the whole listing. Same pattern needed at get_task_state (line 133), and in asset_state.py at lines 90 and 121.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.3 hasn't shipped, so the pre-#67418 row scenario cannot exist in any deployed cluster.

For the custom backend concern: the core API currently calls get_state_backend(), but that's a known issue I already flagged in the backlog and will be fixing in next PRs, ie: the core API routes should be DB-direct (same as XCom), never routing through a custom backend. Once that is fixed, the core API only ever reads rows written by the execution API into the database, either direct data or references, which always stores json.dumps(value), so json.loads on read is always safe. The JSONDecodeError risk goes away entirely at that point.

Copy link
Copy Markdown
Contributor Author

@amoghrajesh amoghrajesh May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both write paths (task execution and core rest API) guarantee the DB always contains valid JSON: the execution API does json.dumps(value) before storing, and the Core API PUT does json.dumps(body.value) before storing.json.loads on the read path is always safe.

)
for r in rows
]
return TaskStateCollectionResponse(task_states=entries, total_entries=total_entries)
Expand Down Expand Up @@ -127,7 +130,7 @@ def get_task_state(
detail=f"Task state key {key!r} not found",
)
return TaskStateResponse(
key=row.key, value=row.value, updated_at=row.updated_at, expires_at=row.expires_at
key=row.key, value=json.loads(row.value), updated_at=row.updated_at, expires_at=row.expires_at
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same json.loads 500 risk as the list endpoint above (see comment on line 92). A row written by an execution-API build between #66073 and #67418, or by a custom BaseStateBackend, will turn a GET into a 500 instead of returning the value. Worth catching JSONDecodeError explicitly here too.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. No such rows can exist since both write paths (execution API and Core API PUT) always store json.dumps(...) before writing to the DB. json.loads on the read path is always safe.

)


Expand Down Expand Up @@ -162,7 +165,7 @@ def set_task_state(
)
scope = _get_scope(dag_id, dag_run_id, task_id, map_index)
try:
get_state_backend().set(scope, key, body.value, session=session)
get_state_backend().set(scope, key, json.dumps(body.value), session=session)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

Expand Down
14 changes: 4 additions & 10 deletions airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,7 @@ export const $AssetResponse = {
export const $AssetStateBody = {
properties: {
value: {
type: 'string',
maxLength: 65535,
title: 'Value'
'$ref': '#/components/schemas/JsonValue'
}
},
additionalProperties: false,
Expand Down Expand Up @@ -423,8 +421,7 @@ export const $AssetStateResponse = {
title: 'Key'
},
value: {
type: 'string',
title: 'Value'
'$ref': '#/components/schemas/JsonValue'
},
updated_at: {
type: 'string',
Expand Down Expand Up @@ -6982,9 +6979,7 @@ export const $TaskResponse = {
export const $TaskStateBody = {
properties: {
value: {
type: 'string',
maxLength: 65535,
title: 'Value'
'$ref': '#/components/schemas/JsonValue'
}
},
additionalProperties: false,
Expand Down Expand Up @@ -7021,8 +7016,7 @@ export const $TaskStateResponse = {
title: 'Key'
},
value: {
type: 'string',
title: 'Value'
'$ref': '#/components/schemas/JsonValue'
},
updated_at: {
type: 'string',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ export type AssetResponse = {
* Request body for setting an asset state value.
*/
export type AssetStateBody = {
value: string;
value: JsonValue;
};

/**
Expand All @@ -116,7 +116,7 @@ export type AssetStateCollectionResponse = {
*/
export type AssetStateResponse = {
key: string;
value: string;
value: JsonValue;
updated_at: string;
};

Expand Down Expand Up @@ -1715,7 +1715,7 @@ export type TaskResponse = {
* Request body for setting a task state value.
*/
export type TaskStateBody = {
value: string;
value: JsonValue;
};

/**
Expand All @@ -1731,7 +1731,7 @@ export type TaskStateCollectionResponse = {
*/
export type TaskStateResponse = {
key: string;
value: string;
value: JsonValue;
updated_at: string;
expires_at: string | null;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
# under the License.
from __future__ import annotations

import json

import pytest
from pydantic import ValidationError
from sqlalchemy import select

from airflow.api_fastapi.core_api.datamodels.asset_state import AssetStateBody
from airflow.models.asset import AssetModel
from airflow.models.asset_state import AssetStateModel

Expand All @@ -37,7 +42,7 @@ def _create_asset(session) -> AssetModel:


def _create_asset_state(session, asset_id: int, key: str, value: str) -> None:
row = AssetStateModel(asset_id=asset_id, key=key, value=value)
row = AssetStateModel(asset_id=asset_id, key=key, value=json.dumps(value))
session.add(row)
session.flush()

Expand Down Expand Up @@ -172,9 +177,59 @@ def test_overwrites_existing_key(self, test_client):
def test_empty_body_returns_422(self, test_client):
assert test_client.put(f"{self._base_url}/watermark", json={}).status_code == 422

def test_null_value_returns_422(self, test_client):
assert test_client.put(f"{self._base_url}/watermark", json={"value": None}).status_code == 422

@pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")])
def test_non_finite_float_returns_422(self, test_client, bad_float):
response = test_client.put(
f"{self._base_url}/watermark",
content=json.dumps({"value": bad_float}, allow_nan=True).encode(),
headers={"Content-Type": "application/json"},
)
assert response.status_code == 422

def test_oversized_value_returns_422(self, test_client):
assert test_client.put(f"{self._base_url}/watermark", json={"value": "x" * 65536}).status_code == 422

@pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), {"a": float("nan")}, [float("inf")]])
def test_non_finite_float_rejected_by_validator(self, bad_value):
with pytest.raises(ValidationError, match="non-finite"):
AssetStateBody(value=bad_value)

@pytest.mark.parametrize(
("value", "expected_db"),
[
(42, "42"),
("hello", '"hello"'),
({"k": 1}, '{"k": 1}'),
([1, 2], "[1, 2]"),
],
)
def test_put_stores_json_encoded_value(self, test_client, value, expected_db):
test_client.put(f"{self._base_url}/k", json={"value": value})
row = self._session.scalar(
select(AssetStateModel).where(
AssetStateModel.asset_id == self.asset.id,
AssetStateModel.key == "k",
)
)
assert row is not None
assert row.value == expected_db

@pytest.mark.parametrize("value", [42, True, {"rows": 100}, [1, "two"], "hello"])
def test_core_api_write_read_roundtrip(self, test_client, value):
"""Core API write then Core API read returns the same native value."""
test_client.put(f"{self._base_url}/k", json={"value": value})
assert test_client.get(f"{self._base_url}/k").json()["value"] == value

@pytest.mark.parametrize("value", [42, True, {"rows": 100}, [1, "two"], "hello"])
def test_worker_write_core_api_read_roundtrip(self, test_client, value):
"""Worker write (json.dumps in DB) then Core API read returns native value."""
_create_asset_state(self._session, self.asset.id, "k", value)
self._session.commit()
assert test_client.get(f"{self._base_url}/k").json()["value"] == value

def test_key_with_slash_is_supported(self, test_client):
response = test_client.put(f"{self._base_url}/partition/date", json={"value": "2026-05-01"})
assert response.status_code == 204
Expand Down
Loading
Loading