Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/flareio/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests

from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from urllib3.util import Retry

import typing as t
Expand All @@ -34,7 +35,7 @@ def __init__(
tenant_id: t.Optional[int] = None,
session: t.Optional[requests.Session] = None,
api_domain: t.Optional[str] = None,
_disable_auth: bool = False,
_auth: AuthBase | None = None,
_enable_beta_features: bool = False,
) -> None:
if not api_key:
Expand All @@ -52,9 +53,9 @@ def __init__(
self._api_key: str = api_key
self._tenant_id: t.Optional[int] = tenant_id

self._auth: t.Optional[AuthBase] = _auth
self._api_token: t.Optional[str] = None
self._api_token_exp: t.Optional[datetime] = None
self._disable_auth: bool = _disable_auth
self._session = session or self._create_session()

@classmethod
Expand Down Expand Up @@ -135,16 +136,24 @@ def generate_token(self) -> str:

return token

def _auth_headers(self) -> dict:
if self._disable_auth:
return dict()
def _apply_auth(
self,
*,
request: requests.PreparedRequest,
) -> requests.PreparedRequest:
if self._auth:
self._auth(request)
return request

api_token: t.Optional[str] = self._api_token
if not api_token or (
self._api_token_exp and self._api_token_exp < datetime.now()
):
api_token = self.generate_token()

return {"Authorization": f"Bearer {api_token}"}
request.headers["Authorization"] = f"Bearer {api_token}"

return request

def _request(
self,
Expand All @@ -163,19 +172,20 @@ def _request(
f"Client was used to access {netloc=} at {url=}. Only the domain {self._api_domain} is supported."
)

headers = {
**(headers or {}),
**self._auth_headers(),
}

return self._session.request(
request = requests.Request(
method=method,
url=url,
params=params,
json=json,
headers=headers,
)

prepared = self._session.prepare_request(request)
prepared = self._apply_auth(request=prepared)
resp = self._session.send(prepared)

return resp

def post(
self,
url: str,
Expand Down
26 changes: 26 additions & 0 deletions src/flareio/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from requests import PreparedRequest
from requests.auth import AuthBase


class _StaticHeadersAuth(AuthBase):
def __init__(
self,
*,
headers: dict[str, str],
) -> None:
self._headers: dict[str, str] = headers

def __call__(
self,
r: PreparedRequest,
) -> PreparedRequest:
r.headers.update(self._headers)
return r


class _EmptyAuth(AuthBase):
def __call__(
self,
r: PreparedRequest,
) -> PreparedRequest:
return r
50 changes: 50 additions & 0 deletions tests/test_api_client_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import requests_mock

from .utils import get_test_client

from flareio.auth import _EmptyAuth
from flareio.auth import _StaticHeadersAuth


def test_custom_auth_empty() -> None:
client = get_test_client(
authenticated=False,
_auth=_EmptyAuth(),
)
with requests_mock.Mocker() as mocker:
mocker.register_uri(
"POST",
"https://api.flare.io/hello-post",
status_code=200,
)
client.post(
"https://api.flare.io/hello-post",
json={"foo": "bar"},
)
assert not mocker.last_request.headers.get("Authorization")


def test_custom_auth_static() -> None:
client = get_test_client(
authenticated=False,
_auth=_StaticHeadersAuth(
headers={
"first-header": "first-value",
"Authorization": "auth-value",
}
),
)
with requests_mock.Mocker() as mocker:
mocker.register_uri(
"POST",
"https://api.flare.io/hello-post",
status_code=200,
)
client.post(
"https://api.flare.io/hello-post",
json={"foo": "bar"},
headers={"second-header": "second-value"},
)
assert mocker.last_request.headers["Authorization"] == "auth-value"
assert mocker.last_request.headers["first-header"] == "first-value"
assert mocker.last_request.headers["second-header"] == "second-value"
19 changes: 0 additions & 19 deletions tests/test_api_client_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,3 @@ def test_bad_domain() -> None:
match="Client was used to access netloc='bad.com' at url='https://bad.com/hello-post'. Only the domain api.flare.io is supported.",
):
client.post("https://bad.com/hello-post")


def test_disable_auth_does_not_call_generate() -> None:
client = get_test_client(
authenticated=False,
_disable_auth=True,
)
with requests_mock.Mocker() as mocker:
mocker.register_uri(
"POST",
"https://api.flare.io/hello-post",
status_code=200,
)
client.post("https://api.flare.io/hello-post", json={"foo": "bar"})
assert mocker.last_request.url == "https://api.flare.io/hello-post"
assert mocker.last_request.json() == {"foo": "bar"}

# Authorization header should not be present when auth is disabled
assert not mocker.last_request.headers.get("Authorization")
6 changes: 4 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import requests_mock

from requests.auth import AuthBase

import typing as t

from flareio import FlareApiClient
Expand All @@ -11,14 +13,14 @@ def get_test_client(
authenticated: bool = True,
api_domain: t.Optional[str] = None,
_enable_beta_features: bool = False,
_disable_auth: bool = False,
_auth: t.Optional[AuthBase] = None,
) -> FlareApiClient:
client = FlareApiClient(
api_key="test-api-key",
tenant_id=tenant_id,
api_domain=api_domain,
_enable_beta_features=_enable_beta_features,
_disable_auth=_disable_auth,
_auth=_auth,
)

if authenticated:
Expand Down