From b2c7d3302839f533374ae6fc474387b29c6d2208 Mon Sep 17 00:00:00 2001 From: Anusha Kovi Date: Tue, 28 Oct 2025 22:45:39 -0700 Subject: [PATCH 1/3] Fix MyPy type errors in api_fastapi remaining files for Sqlalchemy 2 migration --- .../src/airflow/api_fastapi/common/db/dag_runs.py | 4 ++-- .../src/airflow/api_fastapi/common/parameters.py | 14 ++++++++++---- .../core_api/services/public/dag_run.py | 4 ++-- .../api_fastapi/core_api/routes/ui/test_dags.py | 2 +- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py b/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py index 1ca5f1d261b54..9e016e4e7351f 100644 --- a/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py @@ -23,11 +23,11 @@ from airflow.models.dagrun import DagRun dagruns_select_with_state_count = ( - select( + select( # type: ignore[call-overload] DagRun.dag_id, DagRun.state, DagModel.dag_display_name, - func.count(DagRun.state), + func.count(DagRun.state).label("count"), ) .join(DagModel, DagRun.dag_id == DagModel.dag_id) .group_by(DagRun.dag_id, DagRun.state, DagModel.dag_display_name) diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py b/airflow-core/src/airflow/api_fastapi/common/parameters.py index 6674dd55fef09..5a235dff62e31 100644 --- a/airflow-core/src/airflow/api_fastapi/common/parameters.py +++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py @@ -66,6 +66,8 @@ if TYPE_CHECKING: from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import ColumnElement, Select +else: + from sqlalchemy.orm.attributes import InstrumentedAttribute T = TypeVar("T") @@ -75,7 +77,7 @@ class BaseParam(OrmClause[T], ABC): def __init__(self, value: T | None = None, skip_none: bool = True) -> None: super().__init__(value) - self.attribute: ColumnElement | None = None + self.attribute: ColumnElement | InstrumentedAttribute | None = None self.skip_none = skip_none def set_value(self, value: T | None) -> Self: @@ -387,7 +389,7 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self: def filter_param_factory( - attribute: ColumnElement, + attribute: ColumnElement | InstrumentedAttribute, _type: type, filter_option: FilterOptionEnum = FilterOptionEnum.EQUAL, filter_name: str | None = None, @@ -399,7 +401,7 @@ def filter_param_factory( description: str | None = None, ) -> Callable[[T | None], FilterParam[T | None]]: # if filter_name is not provided, use the attribute name as the default - filter_name = filter_name or attribute.name + filter_name = filter_name or getattr(attribute, "name", str(attribute)) # can only set either default_value or default_factory query = ( Query(alias=filter_name, default_factory=default_factory, description=description) @@ -410,7 +412,11 @@ def filter_param_factory( def depends_filter(value: T | None = query) -> FilterParam[T | None]: if transform_callable: value = transform_callable(value) - return FilterParam(attribute, value, filter_option, skip_none) + # Cast to InstrumentedAttribute for type compatibility + from typing import cast + + attr = cast("InstrumentedAttribute", attribute) + return FilterParam(attr, value, filter_option, skip_none) # add type hint to value at runtime depends_filter.__annotations__["value"] = _type diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py index 37a5d761b3033..e2521b72eb5a3 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py @@ -57,7 +57,7 @@ def _serialize_xcoms(self) -> dict[str, Any]: task_ids=self.result_task_ids, dag_ids=self.dag_id, ) - xcom_query = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)).all() + xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)).all() def _group_xcoms(g: Iterator[XComModel]) -> Any: entries = list(g) @@ -67,7 +67,7 @@ def _group_xcoms(g: Iterator[XComModel]) -> Any: return { task_id: _group_xcoms(g) - for task_id, g in itertools.groupby(xcom_query, key=operator.attrgetter("task_id")) + for task_id, g in itertools.groupby(xcom_results, key=operator.attrgetter("task_id")) } def _serialize_response(self, dag_run: DagRun) -> str: diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py index 2d1a896542d8a..5f2c0e9c80843 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py @@ -69,7 +69,7 @@ def setup_dag_runs(self, session=None) -> None: triggered_by=DagRunTriggeredByType.TEST, ) if dag_run.start_date is not None: - dag_run.end_date = dag_run.start_date.add(hours=1) + dag_run.end_date = dag_run.start_date + pendulum.duration(hours=1) session.add(dag_run) session.commit() From 005381781babe0e73cfca063bac7135411f490e7 Mon Sep 17 00:00:00 2001 From: Anusha Kovi Date: Wed, 29 Oct 2025 20:23:18 -0700 Subject: [PATCH 2/3] add imports at the top --- airflow-core/src/airflow/api_fastapi/common/parameters.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py b/airflow-core/src/airflow/api_fastapi/common/parameters.py index 5a235dff62e31..ffd5f554da15e 100644 --- a/airflow-core/src/airflow/api_fastapi/common/parameters.py +++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py @@ -28,6 +28,7 @@ Generic, Literal, TypeVar, + cast, overload, ) @@ -413,8 +414,6 @@ def depends_filter(value: T | None = query) -> FilterParam[T | None]: if transform_callable: value = transform_callable(value) # Cast to InstrumentedAttribute for type compatibility - from typing import cast - attr = cast("InstrumentedAttribute", attribute) return FilterParam(attr, value, filter_option, skip_none) From 53c6ec077b927339802a62be6964b8b563ed5f7b Mon Sep 17 00:00:00 2001 From: Anusha Kovi Date: Wed, 29 Oct 2025 21:49:18 -0700 Subject: [PATCH 3/3] remove duplicate import --- airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py | 7 +++++-- airflow-core/src/airflow/api_fastapi/common/parameters.py | 3 +-- .../api_fastapi/core_api/services/public/dag_run.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py b/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py index 9e016e4e7351f..2efb6967653e8 100644 --- a/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py @@ -17,17 +17,20 @@ from __future__ import annotations +from typing import cast + from sqlalchemy import func, select +from sqlalchemy.sql import ColumnElement from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun dagruns_select_with_state_count = ( - select( # type: ignore[call-overload] + select( DagRun.dag_id, DagRun.state, DagModel.dag_display_name, - func.count(DagRun.state).label("count"), + cast("ColumnElement[int]", func.count(DagRun.state).label("count")), ) .join(DagModel, DagRun.dag_id == DagModel.dag_id) .group_by(DagRun.dag_id, DagRun.state, DagModel.dag_display_name) diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py b/airflow-core/src/airflow/api_fastapi/common/parameters.py index ffd5f554da15e..dd7b0d808a7cf 100644 --- a/airflow-core/src/airflow/api_fastapi/common/parameters.py +++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py @@ -67,8 +67,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import ColumnElement, Select -else: - from sqlalchemy.orm.attributes import InstrumentedAttribute + T = TypeVar("T") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py index e2521b72eb5a3..5a08ed1c3b065 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py @@ -57,7 +57,7 @@ def _serialize_xcoms(self) -> dict[str, Any]: task_ids=self.result_task_ids, dag_ids=self.dag_id, ) - xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)).all() + xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)) def _group_xcoms(g: Iterator[XComModel]) -> Any: entries = list(g)