Skip to content

Commit

Permalink
feat(Client): expose client extra headers in init function (#1715)
Browse files Browse the repository at this point in the history
* feat(Client): expose client extra headers in init function

* tests: adding missing tests

* docs: include some example of how to use additional headers

* Apply suggestions from code review

Co-authored-by: Daniel Vila Suero <daniel@recogn.ai>

Co-authored-by: Daniel Vila Suero <daniel@recogn.ai>

Closes #1706

(cherry picked from commit 994494f)
  • Loading branch information
frascuchon committed Sep 30, 2022
1 parent 7e3f708 commit 9abf784
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
15 changes: 13 additions & 2 deletions src/rubrix/client/api.py
Expand Up @@ -113,6 +113,7 @@ def __init__(
api_key: Optional[str] = None,
workspace: Optional[str] = None,
timeout: int = 60,
extra_headers: Optional[Dict[str, str]] = None,
):
"""Init the Python client.
Expand All @@ -127,22 +128,32 @@ def __init__(
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``RUBRIX_WORKSPACE`` is not set, it will default to the private user workspace.
timeout: Wait `timeout` seconds for the connection to timeout. Default: 60.
extra_headers: Extra HTTP headers sent to the server. You can use this to customize
the headers of Rubrix client requests, like additional security restrictions. Default: `None`.
Examples:
>>> import rubrix as rb
>>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y")
>>> # Customizing request headers
>>> headers = {"X-Client-id":"id","X-Secret":"secret"}
>>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y", extra_headers=headers)
"""
api_url = api_url or os.getenv("RUBRIX_API_URL", "http://localhost:6900")
# Checking that the api_url does not end in '/'
api_url = re.sub(r"\/$", "", api_url)
api_key = api_key or os.getenv("RUBRIX_API_KEY", DEFAULT_API_KEY)
workspace = workspace or os.getenv("RUBRIX_WORKSPACE")
headers = extra_headers or {}

self._client: AuthenticatedClient = AuthenticatedClient(
base_url=api_url, token=api_key, timeout=timeout
base_url=api_url,
token=api_key,
timeout=timeout,
headers=headers.copy(),
)
self._user: User = users_api.whoami(client=self._client)

self._user: User = users_api.whoami(client=self._client)
if workspace is not None:
self.set_workspace(workspace)

Expand Down
18 changes: 16 additions & 2 deletions tests/client/test_init.py
Expand Up @@ -15,8 +15,8 @@
from rubrix.client import api


def test_resource_leaking_with_several_inits(mocked_client):
dataset = "test_resource_leaking_with_several_inits"
def test_resource_leaking_with_several_init(mocked_client):
dataset = "test_resource_leaking_with_several_init"
api.delete(dataset)

# TODO: review performance in Windows. See https://github.com/recognai/rubrix/pull/1702
Expand All @@ -30,3 +30,17 @@ def test_resource_leaking_with_several_inits(mocked_client):
)

assert len(api.load(dataset)) == 10


def test_init_with_extra_headers(mocked_client):
expected_headers = {
"X-Custom-Header": "Mocking rules!",
"Other-header": "Header value",
}
api.init(extra_headers=expected_headers)
active_api = api.active_api()

for key, value in expected_headers.items():
assert (
active_api.client.headers[key] == value
), f"{key}:{value} not in client headers"

0 comments on commit 9abf784

Please sign in to comment.