Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,9 +870,18 @@ def trigger(
)

try:
self.client.post(
f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True)
self.client._request_without_retry(
"POST", f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True)
)
except httpx.RequestError:
if not reset_dag_run and self._dag_run_exists(dag_id=dag_id, run_id=run_id):
log.info(
"Dag Run exists after ambiguous trigger response; treating trigger as successful.",
dag_id=dag_id,
run_id=run_id,
)
return OKResponse(ok=True)
raise
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.CONFLICT:
if reset_dag_run:
Expand All @@ -885,6 +894,15 @@ def trigger(

return OKResponse(ok=True)

def _dag_run_exists(self, dag_id: str, run_id: str) -> bool:
try:
self.client.get(f"dag-runs/{dag_id}/{run_id}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
return False
raise
return True

def clear(self, dag_id: str, run_id: str) -> OKResponse:
"""Clear a Dag run via the API server."""
self.client.post(f"dag-runs/{dag_id}/{run_id}/clear")
Expand Down Expand Up @@ -1125,6 +1143,19 @@ def _update_auth(self, response: httpx.Response):
log.debug("Execution API issued us a refreshed Task token")
self.auth = BearerAuth(new_token)

@staticmethod
def _ensure_json_content_type(kwargs: dict[str, Any]) -> None:
# Set content type as convenience if not already set
if kwargs.get("content", None) is not None and "content-type" not in (
kwargs.get("headers", {}) or {}
):
kwargs["headers"] = {"content-type": "application/json"}

def _request_without_retry(self, *args, **kwargs):
"""Implement a convenience for httpx.Client.request without retrying."""
self._ensure_json_content_type(kwargs)
return super().request(*args, **kwargs)

@retry(
retry=retry_if_exception(_should_retry_api_request),
stop=stop_after_attempt(API_RETRIES),
Expand All @@ -1134,12 +1165,7 @@ def _update_auth(self, response: httpx.Response):
)
def request(self, *args, **kwargs):
"""Implement a convenience for httpx.Client.request with a retry layer."""
# Set content type as convenience if not already set
if kwargs.get("content", None) is not None and "content-type" not in (
kwargs.get("headers", {}) or {}
):
kwargs["headers"] = {"content-type": "application/json"}

self._ensure_json_content_type(kwargs)
return super().request(*args, **kwargs)

# We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
Expand Down
57 changes: 57 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,63 @@ def handle_request(request: httpx.Request) -> httpx.Response:

assert result == OKResponse(ok=True)

def test_trigger_treats_ambiguous_request_error_as_success_when_dag_run_exists(self):
requests: list[tuple[str, str]] = []

def handle_request(request: httpx.Request) -> httpx.Response:
requests.append((request.method, request.url.path))
if request.method == "POST" and request.url.path == "/dag-runs/test_trigger/test_run_id":
raise httpx.ReadError("Trigger response was lost", request=request)
if request.method == "GET" and request.url.path == "/dag-runs/test_trigger/test_run_id":
return httpx.Response(status_code=200, json={"detail": "exists"})
return httpx.Response(status_code=422)

client = make_client(transport=httpx.MockTransport(handle_request))
result = client.dag_runs.trigger(dag_id="test_trigger", run_id="test_run_id")

assert result == OKResponse(ok=True)
assert requests == [
("POST", "/dag-runs/test_trigger/test_run_id"),
("GET", "/dag-runs/test_trigger/test_run_id"),
]

def test_trigger_reraises_ambiguous_request_error_when_dag_run_is_missing(self):
requests: list[tuple[str, str]] = []

def handle_request(request: httpx.Request) -> httpx.Response:
requests.append((request.method, request.url.path))
if request.method == "POST" and request.url.path == "/dag-runs/test_trigger/test_run_id":
raise httpx.ReadError("Trigger response was lost", request=request)
if request.method == "GET" and request.url.path == "/dag-runs/test_trigger/test_run_id":
return httpx.Response(status_code=404, json={"detail": "Dag run not found"})
return httpx.Response(status_code=422)

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(httpx.ReadError, match="Trigger response was lost"):
client.dag_runs.trigger(dag_id="test_trigger", run_id="test_run_id")

assert requests == [
("POST", "/dag-runs/test_trigger/test_run_id"),
("GET", "/dag-runs/test_trigger/test_run_id"),
]

def test_trigger_reraises_ambiguous_request_error_when_resetting_dag_run(self):
requests: list[tuple[str, str]] = []

def handle_request(request: httpx.Request) -> httpx.Response:
requests.append((request.method, request.url.path))
if request.method == "POST" and request.url.path == "/dag-runs/test_trigger/test_run_id":
raise httpx.ReadError("Trigger response was lost", request=request)
return httpx.Response(status_code=422)

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(httpx.ReadError, match="Trigger response was lost"):
client.dag_runs.trigger(dag_id="test_trigger", run_id="test_run_id", reset_dag_run=True)

assert requests == [("POST", "/dag-runs/test_trigger/test_run_id")]

def test_trigger_conflict(self):
"""Test that if the dag run already exists, the client returns an error when default reset_dag_run=False"""

Expand Down
Loading