From 584b46d2c64daeab1e71844c00eaa2474717fa56 Mon Sep 17 00:00:00 2001 From: Hemkumar Chheda Date: Thu, 14 May 2026 13:04:12 +0530 Subject: [PATCH] Avoid false trigger Dag run conflicts after ambiguous retry closes: #66905 --- task-sdk/src/airflow/sdk/api/client.py | 42 +++++++++++++--- task-sdk/tests/task_sdk/api/test_client.py | 57 ++++++++++++++++++++++ 2 files changed, 91 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 269978ac9dd1d..0bffa6b136137 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -870,9 +870,18 @@ def trigger( ) try: - self.client.post( - f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True) + self.client._request_without_retry( + "POST", f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True) ) + except httpx.RequestError: + if not reset_dag_run and self._dag_run_exists(dag_id=dag_id, run_id=run_id): + log.info( + "Dag Run exists after ambiguous trigger response; treating trigger as successful.", + dag_id=dag_id, + run_id=run_id, + ) + return OKResponse(ok=True) + raise except ServerResponseError as e: if e.response.status_code == HTTPStatus.CONFLICT: if reset_dag_run: @@ -885,6 +894,15 @@ def trigger( return OKResponse(ok=True) + def _dag_run_exists(self, dag_id: str, run_id: str) -> bool: + try: + self.client.get(f"dag-runs/{dag_id}/{run_id}") + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + return False + raise + return True + def clear(self, dag_id: str, run_id: str) -> OKResponse: """Clear a Dag run via the API server.""" self.client.post(f"dag-runs/{dag_id}/{run_id}/clear") @@ -1125,6 +1143,19 @@ def _update_auth(self, response: httpx.Response): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token) + @staticmethod + def _ensure_json_content_type(kwargs: dict[str, Any]) -> None: + # Set content type as convenience if not already set + if kwargs.get("content", None) is not None and "content-type" not in ( + kwargs.get("headers", {}) or {} + ): + kwargs["headers"] = {"content-type": "application/json"} + + def _request_without_retry(self, *args, **kwargs): + """Implement a convenience for httpx.Client.request without retrying.""" + self._ensure_json_content_type(kwargs) + return super().request(*args, **kwargs) + @retry( retry=retry_if_exception(_should_retry_api_request), stop=stop_after_attempt(API_RETRIES), @@ -1134,12 +1165,7 @@ def _update_auth(self, response: httpx.Response): ) def request(self, *args, **kwargs): """Implement a convenience for httpx.Client.request with a retry layer.""" - # Set content type as convenience if not already set - if kwargs.get("content", None) is not None and "content-type" not in ( - kwargs.get("headers", {}) or {} - ): - kwargs["headers"] = {"content-type": "application/json"} - + self._ensure_json_content_type(kwargs) return super().request(*args, **kwargs) # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index a179ff08436b2..94f4fc278ef0c 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -1280,6 +1280,63 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result == OKResponse(ok=True) + def test_trigger_treats_ambiguous_request_error_as_success_when_dag_run_exists(self): + requests: list[tuple[str, str]] = [] + + def handle_request(request: httpx.Request) -> httpx.Response: + requests.append((request.method, request.url.path)) + if request.method == "POST" and request.url.path == "/dag-runs/test_trigger/test_run_id": + raise httpx.ReadError("Trigger response was lost", request=request) + if request.method == "GET" and request.url.path == "/dag-runs/test_trigger/test_run_id": + return httpx.Response(status_code=200, json={"detail": "exists"}) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.trigger(dag_id="test_trigger", run_id="test_run_id") + + assert result == OKResponse(ok=True) + assert requests == [ + ("POST", "/dag-runs/test_trigger/test_run_id"), + ("GET", "/dag-runs/test_trigger/test_run_id"), + ] + + def test_trigger_reraises_ambiguous_request_error_when_dag_run_is_missing(self): + requests: list[tuple[str, str]] = [] + + def handle_request(request: httpx.Request) -> httpx.Response: + requests.append((request.method, request.url.path)) + if request.method == "POST" and request.url.path == "/dag-runs/test_trigger/test_run_id": + raise httpx.ReadError("Trigger response was lost", request=request) + if request.method == "GET" and request.url.path == "/dag-runs/test_trigger/test_run_id": + return httpx.Response(status_code=404, json={"detail": "Dag run not found"}) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(httpx.ReadError, match="Trigger response was lost"): + client.dag_runs.trigger(dag_id="test_trigger", run_id="test_run_id") + + assert requests == [ + ("POST", "/dag-runs/test_trigger/test_run_id"), + ("GET", "/dag-runs/test_trigger/test_run_id"), + ] + + def test_trigger_reraises_ambiguous_request_error_when_resetting_dag_run(self): + requests: list[tuple[str, str]] = [] + + def handle_request(request: httpx.Request) -> httpx.Response: + requests.append((request.method, request.url.path)) + if request.method == "POST" and request.url.path == "/dag-runs/test_trigger/test_run_id": + raise httpx.ReadError("Trigger response was lost", request=request) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(httpx.ReadError, match="Trigger response was lost"): + client.dag_runs.trigger(dag_id="test_trigger", run_id="test_run_id", reset_dag_run=True) + + assert requests == [("POST", "/dag-runs/test_trigger/test_run_id")] + def test_trigger_conflict(self): """Test that if the dag run already exists, the client returns an error when default reset_dag_run=False"""