diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 8dbdf25f7e014..2dbdb2f4a654a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -30,6 +30,7 @@ ) from fastapi import FastAPI, Request, Response from fastapi.responses import JSONResponse +from sqlalchemy.exc import SQLAlchemyError from starlette.middleware.base import BaseHTTPMiddleware from airflow.api_fastapi.auth.tokens import ( @@ -271,6 +272,19 @@ def handle_exceptions(request: Request, exc: Exception): content["correlation-id"] = correlation_id return JSONResponse(status_code=500, content=content) + @app.exception_handler(SQLAlchemyError) + def handle_database_exceptions(request: Request, exc: SQLAlchemyError): + logger.exception( + "Database error handling request", + path=request.url.path, + method=request.method, + exc_info=(type(exc), exc, exc.__traceback__), + ) + content: dict[str, str] = {"detail": "Database error occurred"} + if correlation_id := request.headers.get("correlation-id"): + content["correlation-id"] = correlation_id + return JSONResponse(status_code=500, content=content) + return app diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index b9242583e6977..a30846272e7f5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -35,7 +35,7 @@ from pydantic import JsonValue from sqlalchemy import and_, func, or_, tuple_, update from sqlalchemy.engine import CursorResult -from sqlalchemy.exc import NoResultFound, SQLAlchemyError +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import joinedload from sqlalchemy.sql import select from structlog.contextvars import bind_contextvars @@ -242,74 +242,66 @@ def ti_run( retry_reason=None, ) - try: - result = session.execute(query) - log.info("Task instance state updated", rows_affected=getattr(result, "rowcount", 0)) - - dr = ( - session.scalars( - select(DR) - .filter_by(dag_id=ti.dag_id, run_id=ti.run_id) - .options(joinedload(DR.consumed_asset_events)) - ) - .unique() - .one_or_none() - ) + result = session.execute(query) + log.info("Task instance state updated", rows_affected=getattr(result, "rowcount", 0)) - if not dr: - log.error("DagRun not found", dag_id=ti.dag_id, run_id=ti.run_id) - raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.") - - # Send the keys to the SDK so that the client requests to clear those XComs from the server. - # The reason we cannot do this here in the server is because we need to issue a purge on custom XCom backends - # too. With the current assumption, the workers ONLY have access to the custom XCom backends directly and they - # can issue the purge. - - # However, do not clear it for deferral - xcom_keys = [] - if not ti.next_method: - map_index = None if ti.map_index < 0 else ti.map_index - xcom_query = select(XComModel.key).where( - XComModel.dag_id == ti.dag_id, - XComModel.task_id == ti.task_id, - XComModel.run_id == ti.run_id, - ) - if map_index is not None: - xcom_query = xcom_query.where(XComModel.map_index == map_index) + dr = ( + session.scalars( + select(DR) + .filter_by(dag_id=ti.dag_id, run_id=ti.run_id) + .options(joinedload(DR.consumed_asset_events)) + ) + .unique() + .one_or_none() + ) - xcom_keys = list(session.scalars(xcom_query)) - task_reschedule_count = ( - session.scalar( - select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == task_instance_id) - ) - or 0 + if not dr: + log.error("DagRun not found", dag_id=ti.dag_id, run_id=ti.run_id) + raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.") + + # Send the keys to the SDK so that the client requests to clear those XComs from the server. + # The reason we cannot do this here in the server is because we need to issue a purge on custom XCom backends + # too. With the current assumption, the workers ONLY have access to the custom XCom backends directly and they + # can issue the purge. + + # However, do not clear it for deferral + xcom_keys = [] + if not ti.next_method: + map_index = None if ti.map_index < 0 else ti.map_index + xcom_query = select(XComModel.key).where( + XComModel.dag_id == ti.dag_id, + XComModel.task_id == ti.task_id, + XComModel.run_id == ti.run_id, ) + if map_index is not None: + xcom_query = xcom_query.where(XComModel.map_index == map_index) - from airflow.api_fastapi.execution_api.security import get_team_name_for_ti + xcom_keys = list(session.scalars(xcom_query)) + task_reschedule_count = ( + session.scalar(select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == task_instance_id)) + or 0 + ) - dr.team_name = get_team_name_for_ti(task_instance_id, session) + from airflow.api_fastapi.execution_api.security import get_team_name_for_ti - context = TIRunContext( - dag_run=dr, - task_reschedule_count=task_reschedule_count, - max_tries=ti.max_tries, - # TODO: Add variables and connections that are needed (and has perms) for the task - variables=[], - connections=[], - xcom_keys_to_clear=xcom_keys, - should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), - ) + dr.team_name = get_team_name_for_ti(task_instance_id, session) - # Only set if they are non-null - if ti.next_method: - context.next_method = ti.next_method - context.next_kwargs = ti.next_kwargs - context.start_date = ti.start_date - except SQLAlchemyError: - log.exception("Error marking Task Instance state as running") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" - ) + context = TIRunContext( + dag_run=dr, + task_reschedule_count=task_reschedule_count, + max_tries=ti.max_tries, + # TODO: Add variables and connections that are needed (and has perms) for the task + variables=[], + connections=[], + xcom_keys_to_clear=xcom_keys, + should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), + ) + + # Only set if they are non-null + if ti.next_method: + context.next_method = ti.next_method + context.next_kwargs = ti.next_kwargs + context.start_date = ti.start_date # JWTReissueMiddleware also writes Refreshed-API-Token but skips workload tokens, so we set it here for the workload→execution swap. if token.claims.scope == "workload": @@ -435,33 +427,25 @@ def ti_update_state( if ti is not None: _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag) - # TODO: Replace this with FastAPI's Custom Exception handling: - # https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers - try: - result = session.execute(query) - log.info( - "Task instance state updated", - new_state=updated_state, - rows_affected=getattr(result, "rowcount", 0), - ) - session.add( - Log( - event=updated_state.value, - task_id=task_id, - dag_id=dag_id, - run_id=run_id, - map_index=map_index, - try_number=try_number, - logical_date=logical_date, - owner=owners, - extra=json.dumps({"host_name": hostname}) if hostname else None, - ) - ) - except SQLAlchemyError as e: - log.error("Error updating Task Instance state", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" + result = session.execute(query) + log.info( + "Task instance state updated", + new_state=updated_state, + rows_affected=getattr(result, "rowcount", 0), + ) + session.add( + Log( + event=updated_state.value, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + try_number=try_number, + logical_date=logical_date, + owner=owners, + extra=json.dumps({"host_name": hostname}) if hostname else None, ) + ) if updated_state == TaskInstanceState.SUCCESS: if conf.getboolean("state_store", "clear_on_success"):