From 624c2e35c3259a457b16930d6141e2849977b2c7 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 26 May 2026 16:11:55 +0530 Subject: [PATCH 1/2] Unify task/asset state storage between Core API and Execution API --- .../core_api/datamodels/asset_state.py | 21 ++++++- .../core_api/datamodels/task_state.py | 21 ++++++- .../openapi/v2-rest-api-generated.yaml | 14 ++--- .../core_api/routes/public/asset_state.py | 9 ++- .../core_api/routes/public/task_state.py | 9 ++- .../ui/openapi-gen/requests/schemas.gen.ts | 14 ++--- .../ui/openapi-gen/requests/types.gen.ts | 8 +-- .../routes/public/test_asset_state.py | 50 ++++++++++++++++- .../core_api/routes/public/test_task_state.py | 55 ++++++++++++++++++- .../airflowctl/api/datamodels/generated.py | 54 +++++++++--------- 10 files changed, 188 insertions(+), 67 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py index 379d780290938..6aff19a538535 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py @@ -16,18 +16,22 @@ # under the License. from __future__ import annotations +import json +import math from datetime import datetime -from pydantic import Field +from pydantic import JsonValue, field_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel +_MAX_SERIALIZED_BYTES = 65535 + class AssetStateResponse(BaseModel): """A single asset state key/value pair with metadata.""" key: str - value: str + value: JsonValue updated_at: datetime @@ -41,4 +45,15 @@ class AssetStateCollectionResponse(BaseModel): class AssetStateBody(StrictBaseModel): """Request body for setting an asset state value.""" - value: str = Field(max_length=65535) + value: JsonValue + + @field_validator("value") + @classmethod + def value_is_json_representable(cls, v: JsonValue) -> JsonValue: + if v is None: + raise ValueError("value cannot be null") + if isinstance(v, float) and not math.isfinite(v): + raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") + if len(json.dumps(v)) > _MAX_SERIALIZED_BYTES: + raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") + return v diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py index 856de74a0877b..d289e29ae532e 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py @@ -16,18 +16,22 @@ # under the License. from __future__ import annotations +import json +import math from datetime import datetime -from pydantic import Field +from pydantic import JsonValue, field_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel +_MAX_SERIALIZED_BYTES = 65535 + class TaskStateResponse(BaseModel): """A single task state key/value pair with metadata.""" key: str - value: str + value: JsonValue updated_at: datetime expires_at: datetime | None @@ -42,4 +46,15 @@ class TaskStateCollectionResponse(BaseModel): class TaskStateBody(StrictBaseModel): """Request body for setting a task state value.""" - value: str = Field(max_length=65535) + value: JsonValue + + @field_validator("value") + @classmethod + def value_is_json_representable(cls, v: JsonValue) -> JsonValue: + if v is None: + raise ValueError("value cannot be null") + if isinstance(v, float) and not math.isfinite(v): + raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") + if len(json.dumps(v)) > _MAX_SERIALIZED_BYTES: + raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") + return v diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 23f2622b13b3f..0996d8a37fc3f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -11360,9 +11360,7 @@ components: AssetStateBody: properties: value: - type: string - maxLength: 65535 - title: Value + $ref: '#/components/schemas/JsonValue' additionalProperties: false type: object required: @@ -11391,8 +11389,7 @@ components: type: string title: Key value: - type: string - title: Value + $ref: '#/components/schemas/JsonValue' updated_at: type: string format: date-time @@ -15810,9 +15807,7 @@ components: TaskStateBody: properties: value: - type: string - maxLength: 65535 - title: Value + $ref: '#/components/schemas/JsonValue' additionalProperties: false type: object required: @@ -15841,8 +15836,7 @@ components: type: string title: Key value: - type: string - title: Value + $ref: '#/components/schemas/JsonValue' updated_at: type: string format: date-time diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/asset_state.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/asset_state.py index 43580e0c762c7..877bd5f8f44ee 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/asset_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from typing import Annotated from fastapi import Depends, HTTPException, status @@ -85,7 +86,9 @@ def list_asset_states( session=session, ) rows = session.execute(paginated).all() - entries = [AssetStateResponse(key=r.key, value=r.value, updated_at=r.updated_at) for r in rows] + entries = [ + AssetStateResponse(key=r.key, value=json.loads(r.value), updated_at=r.updated_at) for r in rows + ] return AssetStateCollectionResponse(asset_states=entries, total_entries=total_entries) @@ -115,7 +118,7 @@ def get_asset_state( status_code=status.HTTP_404_NOT_FOUND, detail=f"Asset state key {key!r} not found", ) - return AssetStateResponse(key=row.key, value=row.value, updated_at=row.updated_at) + return AssetStateResponse(key=row.key, value=json.loads(row.value), updated_at=row.updated_at) @asset_state_router.put( @@ -131,7 +134,7 @@ def set_asset_state( session: SessionDep, ) -> None: """Set an asset state value. Creates or overwrites the key.""" - get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session) + get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session) @asset_state_router.delete( diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py index 138380232a8aa..31cc7272ddca6 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from typing import Annotated from fastapi import Depends, HTTPException, Query, status @@ -87,7 +88,9 @@ def list_task_states( ) rows = session.execute(paginated).all() entries = [ - TaskStateResponse(key=r.key, value=r.value, updated_at=r.updated_at, expires_at=r.expires_at) + TaskStateResponse( + key=r.key, value=json.loads(r.value), updated_at=r.updated_at, expires_at=r.expires_at + ) for r in rows ] return TaskStateCollectionResponse(task_states=entries, total_entries=total_entries) @@ -127,7 +130,7 @@ def get_task_state( detail=f"Task state key {key!r} not found", ) return TaskStateResponse( - key=row.key, value=row.value, updated_at=row.updated_at, expires_at=row.expires_at + key=row.key, value=json.loads(row.value), updated_at=row.updated_at, expires_at=row.expires_at ) @@ -162,7 +165,7 @@ def set_task_state( ) scope = _get_scope(dag_id, dag_run_id, task_id, map_index) try: - get_state_backend().set(scope, key, body.value, session=session) + get_state_backend().set(scope, key, json.dumps(body.value), session=session) except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 4a4b95f183023..5ec94d6388176 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -384,9 +384,7 @@ export const $AssetResponse = { export const $AssetStateBody = { properties: { value: { - type: 'string', - maxLength: 65535, - title: 'Value' + '$ref': '#/components/schemas/JsonValue' } }, additionalProperties: false, @@ -423,8 +421,7 @@ export const $AssetStateResponse = { title: 'Key' }, value: { - type: 'string', - title: 'Value' + '$ref': '#/components/schemas/JsonValue' }, updated_at: { type: 'string', @@ -6982,9 +6979,7 @@ export const $TaskResponse = { export const $TaskStateBody = { properties: { value: { - type: 'string', - maxLength: 65535, - title: 'Value' + '$ref': '#/components/schemas/JsonValue' } }, additionalProperties: false, @@ -7021,8 +7016,7 @@ export const $TaskStateResponse = { title: 'Key' }, value: { - type: 'string', - title: 'Value' + '$ref': '#/components/schemas/JsonValue' }, updated_at: { type: 'string', diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 70c9fa131c191..cd0060ca4d4ec 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -100,7 +100,7 @@ export type AssetResponse = { * Request body for setting an asset state value. */ export type AssetStateBody = { - value: string; + value: JsonValue; }; /** @@ -116,7 +116,7 @@ export type AssetStateCollectionResponse = { */ export type AssetStateResponse = { key: string; - value: string; + value: JsonValue; updated_at: string; }; @@ -1715,7 +1715,7 @@ export type TaskResponse = { * Request body for setting a task state value. */ export type TaskStateBody = { - value: string; + value: JsonValue; }; /** @@ -1731,7 +1731,7 @@ export type TaskStateCollectionResponse = { */ export type TaskStateResponse = { key: string; - value: string; + value: JsonValue; updated_at: string; expires_at: string | null; }; diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py index c53fbb99ee91b..3138973aafcb1 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py @@ -16,7 +16,10 @@ # under the License. from __future__ import annotations +import json + import pytest +from sqlalchemy import select from airflow.models.asset import AssetModel from airflow.models.asset_state import AssetStateModel @@ -37,7 +40,7 @@ def _create_asset(session) -> AssetModel: def _create_asset_state(session, asset_id: int, key: str, value: str) -> None: - row = AssetStateModel(asset_id=asset_id, key=key, value=value) + row = AssetStateModel(asset_id=asset_id, key=key, value=json.dumps(value)) session.add(row) session.flush() @@ -172,9 +175,54 @@ def test_overwrites_existing_key(self, test_client): def test_empty_body_returns_422(self, test_client): assert test_client.put(f"{self._base_url}/watermark", json={}).status_code == 422 + def test_null_value_returns_422(self, test_client): + assert test_client.put(f"{self._base_url}/watermark", json={"value": None}).status_code == 422 + + @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) + def test_non_finite_float_returns_422(self, test_client, bad_float): + with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): + test_client.put( + f"{self._base_url}/watermark", + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + def test_oversized_value_returns_422(self, test_client): assert test_client.put(f"{self._base_url}/watermark", json={"value": "x" * 65536}).status_code == 422 + @pytest.mark.parametrize( + ("value", "expected_db"), + [ + (42, "42"), + ("hello", '"hello"'), + ({"k": 1}, '{"k": 1}'), + ([1, 2], "[1, 2]"), + ], + ) + def test_put_stores_json_encoded_value(self, test_client, value, expected_db): + test_client.put(f"{self._base_url}/k", json={"value": value}) + row = self._session.scalar( + select(AssetStateModel).where( + AssetStateModel.asset_id == self.asset.id, + AssetStateModel.key == "k", + ) + ) + assert row is not None + assert row.value == expected_db + + @pytest.mark.parametrize("value", [42, True, {"rows": 100}, [1, "two"], "hello"]) + def test_core_api_write_read_roundtrip(self, test_client, value): + """Core API write then Core API read returns the same native value.""" + test_client.put(f"{self._base_url}/k", json={"value": value}) + assert test_client.get(f"{self._base_url}/k").json()["value"] == value + + @pytest.mark.parametrize("value", [42, True, {"rows": 100}, [1, "two"], "hello"]) + def test_worker_write_core_api_read_roundtrip(self, test_client, value): + """Worker write (json.dumps in DB) then Core API read returns native value.""" + _create_asset_state(self._session, self.asset.id, "k", value) + self._session.commit() + assert test_client.get(f"{self._base_url}/k").json()["value"] == value + def test_key_with_slash_is_supported(self, test_client): response = test_client.put(f"{self._base_url}/partition/date", json={"value": "2026-05-01"}) assert response.status_code == 204 diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index c53212fa5e444..ae91f000b06a6 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import json + import pytest from sqlalchemy import select @@ -54,7 +56,7 @@ def _create_task_state(session, key: str, value: str, dag_run: DagRun) -> None: task_id=TASK_ID, map_index=-1, key=key, - value=value, + value=json.dumps(value), ) session.add(row) session.flush() @@ -114,7 +116,7 @@ def test_map_index_isolation(self, test_client): task_id=TASK_ID, map_index=0, key="job_id", - value="mapped_app", + value=json.dumps("mapped_app"), ) self._session.add(row) self._session.commit() @@ -191,6 +193,18 @@ def test_overwrites_existing_key(self, test_client): def test_empty_body_returns_422(self, test_client): assert test_client.put(f"{BASE_URL}/job_id", json={}).status_code == 422 + def test_null_value_returns_422(self, test_client): + assert test_client.put(f"{BASE_URL}/job_id", json={"value": None}).status_code == 422 + + @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) + def test_non_finite_float_returns_422(self, test_client, bad_float): + with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): + test_client.put( + f"{BASE_URL}/job_id", + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + def test_oversized_value_returns_422(self, test_client): assert test_client.put(f"{BASE_URL}/job_id", json={"value": "x" * 65536}).status_code == 422 @@ -206,6 +220,41 @@ def test_set_nonexistent_task_id_returns_404(self, test_client): response = test_client.put(bad_url, json={"value": "v"}) assert response.status_code == 404 + @pytest.mark.parametrize( + ("value", "expected_db"), + [ + (42, "42"), + ("hello", '"hello"'), + ({"k": 1}, '{"k": 1}'), + ([1, 2], "[1, 2]"), + ], + ) + def test_put_stores_json_encoded_value(self, test_client, value, expected_db): + test_client.put(f"{BASE_URL}/k", json={"value": value}) + row = self._session.scalar( + select(TaskStateModel).where( + TaskStateModel.dag_id == DAG_ID, + TaskStateModel.run_id == RUN_ID, + TaskStateModel.task_id == TASK_ID, + TaskStateModel.key == "k", + ) + ) + assert row is not None + assert row.value == expected_db + + @pytest.mark.parametrize("value", [42, True, {"rows": 100}, [1, "two"], "hello"]) + def test_core_api_write_read_roundtrip(self, test_client, value): + """Core API write then Core API read returns the same native value.""" + test_client.put(f"{BASE_URL}/k", json={"value": value}) + assert test_client.get(f"{BASE_URL}/k").json()["value"] == value + + @pytest.mark.parametrize("value", [42, True, {"rows": 100}, [1, "two"], "hello"]) + def test_worker_write_core_api_read_roundtrip(self, test_client, value): + """Worker write (json.dumps in DB) then Core API read returns native value.""" + _create_task_state(self._session, "k", value, self.dag_run) + self._session.commit() + assert test_client.get(f"{BASE_URL}/k").json()["value"] == value + def test_key_with_slash_is_supported(self, test_client): response = test_client.put(f"{BASE_URL}/workflow/step_1", json={"value": "v"}) assert response.status_code == 204 @@ -266,7 +315,7 @@ def test_all_map_indices_clears_across_mapped_instances(self, test_client): task_id=TASK_ID, map_index=map_index, key="job_id", - value=f"app_{map_index}", + value=json.dumps(f"app_{map_index}"), ) self._session.add(row) self._session.commit() diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index f05fa65cf56f8..44ac8d3938751 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -49,27 +49,6 @@ class AssetAliasResponse(BaseModel): group: Annotated[str, Field(title="Group")] -class AssetStateBody(BaseModel): - """ - Request body for setting an asset state value. - """ - - model_config = ConfigDict( - extra="forbid", - ) - value: Annotated[str, Field(max_length=65535, title="Value")] - - -class AssetStateResponse(BaseModel): - """ - A single asset state key/value pair with metadata. - """ - - key: Annotated[str, Field(title="Key")] - value: Annotated[str, Field(title="Value")] - updated_at: Annotated[datetime, Field(title="Updated At")] - - class AssetWatcherResponse(BaseModel): """ Asset watcher serializer for responses. @@ -981,7 +960,7 @@ class TaskStateBody(BaseModel): model_config = ConfigDict( extra="forbid", ) - value: Annotated[str, Field(max_length=65535, title="Value")] + value: JsonValue class TaskStateResponse(BaseModel): @@ -990,7 +969,7 @@ class TaskStateResponse(BaseModel): """ key: Annotated[str, Field(title="Key")] - value: Annotated[str, Field(title="Value")] + value: JsonValue updated_at: Annotated[datetime, Field(title="Updated At")] expires_at: Annotated[datetime | None, Field(title="Expires At")] = None @@ -1225,13 +1204,25 @@ class AssetResponse(BaseModel): last_asset_event: LastAssetEventResponse | None = None -class AssetStateCollectionResponse(BaseModel): +class AssetStateBody(BaseModel): """ - All asset state entries for an asset. + Request body for setting an asset state value. """ - asset_states: Annotated[list[AssetStateResponse], Field(title="Asset States")] - total_entries: Annotated[int, Field(title="Total Entries")] + model_config = ConfigDict( + extra="forbid", + ) + value: JsonValue + + +class AssetStateResponse(BaseModel): + """ + A single asset state key/value pair with metadata. + """ + + key: Annotated[str, Field(title="Key")] + value: JsonValue + updated_at: Annotated[datetime, Field(title="Updated At")] class BackfillPostBody(BaseModel): @@ -2017,6 +2008,15 @@ class AssetEventCollectionResponse(BaseModel): total_entries: Annotated[int, Field(title="Total Entries")] +class AssetStateCollectionResponse(BaseModel): + """ + All asset state entries for an asset. + """ + + asset_states: Annotated[list[AssetStateResponse], Field(title="Asset States")] + total_entries: Annotated[int, Field(title="Total Entries")] + + class BackfillCollectionResponse(BaseModel): """ Backfill Collection serializer for responses. From 1e8e0ea38ce60a7fdd86f1009723cc9e5879cdc6 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 27 May 2026 12:52:23 +0530 Subject: [PATCH 2/2] review comments from kaxil --- .../core_api/datamodels/asset_state.py | 9 +++++---- .../core_api/datamodels/task_state.py | 9 +++++---- .../routes/public/test_asset_state.py | 19 +++++++++++++------ .../core_api/routes/public/test_task_state.py | 19 +++++++++++++------ 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py index 6aff19a538535..a373fe89d2d75 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/asset_state.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import math from datetime import datetime from pydantic import JsonValue, field_validator @@ -52,8 +51,10 @@ class AssetStateBody(StrictBaseModel): def value_is_json_representable(cls, v: JsonValue) -> JsonValue: if v is None: raise ValueError("value cannot be null") - if isinstance(v, float) and not math.isfinite(v): - raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") - if len(json.dumps(v)) > _MAX_SERIALIZED_BYTES: + try: + serialized = json.dumps(v, allow_nan=False) + except ValueError: + raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable") + if len(serialized) > _MAX_SERIALIZED_BYTES: raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") return v diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py index d289e29ae532e..e6622f842e116 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import math from datetime import datetime from pydantic import JsonValue, field_validator @@ -53,8 +52,10 @@ class TaskStateBody(StrictBaseModel): def value_is_json_representable(cls, v: JsonValue) -> JsonValue: if v is None: raise ValueError("value cannot be null") - if isinstance(v, float) and not math.isfinite(v): - raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") - if len(json.dumps(v)) > _MAX_SERIALIZED_BYTES: + try: + serialized = json.dumps(v, allow_nan=False) + except ValueError: + raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable") + if len(serialized) > _MAX_SERIALIZED_BYTES: raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") return v diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py index 3138973aafcb1..28b5261ceab0b 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_asset_state.py @@ -19,8 +19,10 @@ import json import pytest +from pydantic import ValidationError from sqlalchemy import select +from airflow.api_fastapi.core_api.datamodels.asset_state import AssetStateBody from airflow.models.asset import AssetModel from airflow.models.asset_state import AssetStateModel @@ -180,16 +182,21 @@ def test_null_value_returns_422(self, test_client): @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) def test_non_finite_float_returns_422(self, test_client, bad_float): - with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): - test_client.put( - f"{self._base_url}/watermark", - content=json.dumps({"value": bad_float}, allow_nan=True).encode(), - headers={"Content-Type": "application/json"}, - ) + response = test_client.put( + f"{self._base_url}/watermark", + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 def test_oversized_value_returns_422(self, test_client): assert test_client.put(f"{self._base_url}/watermark", json={"value": "x" * 65536}).status_code == 422 + @pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), {"a": float("nan")}, [float("inf")]]) + def test_non_finite_float_rejected_by_validator(self, bad_value): + with pytest.raises(ValidationError, match="non-finite"): + AssetStateBody(value=bad_value) + @pytest.mark.parametrize( ("value", "expected_db"), [ diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index ae91f000b06a6..a0f5e48064797 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -19,9 +19,11 @@ import json import pytest +from pydantic import ValidationError from sqlalchemy import select from airflow._shared.timezones import timezone +from airflow.api_fastapi.core_api.datamodels.task_state import TaskStateBody from airflow.models.dagrun import DagRun from airflow.models.task_state import TaskStateModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -198,16 +200,21 @@ def test_null_value_returns_422(self, test_client): @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) def test_non_finite_float_returns_422(self, test_client, bad_float): - with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): - test_client.put( - f"{BASE_URL}/job_id", - content=json.dumps({"value": bad_float}, allow_nan=True).encode(), - headers={"Content-Type": "application/json"}, - ) + response = test_client.put( + f"{BASE_URL}/job_id", + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 def test_oversized_value_returns_422(self, test_client): assert test_client.put(f"{BASE_URL}/job_id", json={"value": "x" * 65536}).status_code == 422 + @pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), {"a": float("nan")}, [float("inf")]]) + def test_non_finite_float_rejected_by_validator(self, bad_value): + with pytest.raises(ValidationError, match="non-finite"): + TaskStateBody(value=bad_value) + def test_set_nonexistent_dag_run_returns_404(self, test_client): """set() raises ValueError when DagRun doesn't exist — should surface as 404.""" bad_url = f"/dags/{DAG_ID}/dagRuns/nonexistent_run/taskInstances/{TASK_ID}/states/job_id"