Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def date_param():
]


@pytest.mark.flaky(reruns=3, reruns_delay=1)
@pytest.mark.parametrize(
"command",
TEST_COMMANDS_DEBUG_MODE,
Expand All @@ -144,7 +143,6 @@ def test_airflowctl_commands(command: str, run_command):
run_command(command, env_vars, skip_login=True)


@pytest.mark.flaky(reruns=3, reruns_delay=1)
@pytest.mark.parametrize(
"command",
TEST_COMMANDS_SKIP_KEYRING,
Expand Down
18 changes: 18 additions & 0 deletions airflow-ctl/docs/cli-and-env-variables-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,21 @@ Environment Variables
It disables some features such as keyring integration and save credentials to file.
It is only meant to use if either you are developing airflowctl or running API integration tests.
Please do not use this variable unless you know what you are doing.

.. envvar:: AIRFLOW_CLI_API_RETRIES

The number of times to retry an API call if it fails. This is
only used if you are using the Airflow API and have not set up
authentication using a different method. The default value is 3.

.. envvar:: AIRFLOW_CLI_API_RETRY_WAIT_MIN

The minimum amount of time to wait between API retries in seconds.
This is only used if you are using the Airflow API and have not set up
authentication using a different method. The default value is 1 second.

.. envvar:: AIRFLOW_CLI_API_RETRY_WAIT_MAX

The maximum amount of time to wait between API retries in seconds.
This is only used if you are using the Airflow API and have not set up
authentication using a different method. The default value is 10 seconds.
1 change: 1 addition & 0 deletions airflow-ctl/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"structlog>=25.4.0",
"uuid6>=2024.7.10",
"tabulate>=0.9.0",
"tenacity>=9.1.4",
]

classifiers = [
Expand Down
39 changes: 39 additions & 0 deletions airflow-ctl/src/airflowctl/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import enum
import getpass
import json
import logging
import os
import sys
from collections.abc import Callable
Expand All @@ -32,6 +33,13 @@
import structlog
from httpx import URL
from keyring.errors import NoKeyringError
from tenacity import (
before_log,
retry,
retry_if_exception,
stop_after_attempt,
wait_random_exponential,
)
from uuid6 import uuid7

from airflowctl import __version__ as version
Expand Down Expand Up @@ -261,6 +269,20 @@ def auth_flow(self, request: httpx.Request):
yield request


def _should_retry_api_request(exception: BaseException) -> bool:
"""Determine if an API request should be retried based on the exception type."""
if isinstance(exception, httpx.HTTPStatusError):
return exception.response.status_code >= 500

return isinstance(exception, httpx.RequestError)


# API Client Retry Configuration
API_RETRIES = int(os.getenv("AIRFLOW_CLI_API_RETRIES", "3"))
API_RETRY_WAIT_MIN = int(os.getenv("AIRFLOW_CLI_API_RETRY_WAIT_MIN", "1"))
API_RETRY_WAIT_MAX = int(os.getenv("AIRFLOW_CLI_API_RETRY_WAIT_MAX", "10"))


class Client(httpx.Client):
"""Client for the Airflow REST API."""

Expand Down Expand Up @@ -298,6 +320,23 @@ def _get_base_url(
return f"{base_url}/auth"
return f"{base_url}/api/v2"

@retry(
retry=retry_if_exception(_should_retry_api_request),
stop=stop_after_attempt(API_RETRIES),
wait=wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX),
before_sleep=before_log(log, logging.WARNING),
reraise=True,
)
def request(self, *args, **kwargs):
"""Implement a convenience for httpx.Client.request with a retry layer."""
# Set content type as convenience if not already set
if kwargs.get("content", None) is not None and "content-type" not in (
kwargs.get("headers", {}) or {}
):
kwargs["headers"] = {"content-type": "application/json"}

return super().request(*args, **kwargs)

@lru_cache() # type: ignore[prop-decorator]
@property
def login(self):
Expand Down
62 changes: 62 additions & 0 deletions airflow-ctl/tests/airflow_ctl/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,23 @@

import httpx
import pytest
import time_machine
from httpx import URL

from airflowctl.api.client import Client, ClientKind, Credentials, _bounded_get_new_password
from airflowctl.api.operations import ServerResponseError
from airflowctl.exceptions import AirflowCtlCredentialNotFoundException, AirflowCtlKeyringException


def make_client_w_responses(responses: list[httpx.Response]) -> Client:
"""Get a client with custom responses."""

def handle_request(request: httpx.Request) -> httpx.Response:
return responses.pop(0)

return Client(base_url="", token="", mounts={"'http://": httpx.MockTransport(handle_request)})


@pytest.fixture(autouse=True)
def unique_config_dir():
temp_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -314,3 +324,55 @@ def test_save_skips_patch_for_non_encrypted_backend(self, mock_keyring):

assert not hasattr(mock_backend, "_get_new_password")
mock_keyring.set_password.assert_called_once_with("airflowctl", "api_token_production", "token")

def test_retry_handling_unrecoverable_error(self):
with time_machine.travel("2023-01-01T00:00:00Z", tick=False):
responses: list[httpx.Response] = [
*[httpx.Response(500, text="Internal Server Error")] * 6,
httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

with pytest.raises(httpx.HTTPStatusError) as err:
client.get("http://error")
assert not isinstance(err.value, ServerResponseError)
assert len(responses) == 5

def test_retry_handling_recovered(self):
with time_machine.travel("2023-01-01T00:00:00Z", tick=False):
responses: list[httpx.Response] = [
*[httpx.Response(500, text="Internal Server Error")] * 2,
httpx.Response(200, json={"detail": "Recovered from error"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

response = client.get("http://error")
assert response.status_code == 200
assert len(responses) == 1

def test_retry_handling_non_retry_error(self):
with time_machine.travel("2023-01-01T00:00:00Z", tick=False):
responses: list[httpx.Response] = [
httpx.Response(422, json={"detail": "Somehow this is a bad request"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

with pytest.raises(ServerResponseError) as err:
client.get("http://error")
assert len(responses) == 1
assert err.value.args == ("Client error message: {'detail': 'Somehow this is a bad request'}",)

def test_retry_handling_ok(self):
with time_machine.travel("2023-01-01T00:00:00Z", tick=False):
responses: list[httpx.Response] = [
httpx.Response(200, json={"detail": "Recovered from error"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

response = client.get("http://error")
assert response.status_code == 200
assert len(responses) == 1
Loading