diff --git a/src/rubrix/client/api.py b/src/rubrix/client/api.py index 18b10a80d0..e3078df451 100644 --- a/src/rubrix/client/api.py +++ b/src/rubrix/client/api.py @@ -113,6 +113,7 @@ def __init__( api_key: Optional[str] = None, workspace: Optional[str] = None, timeout: int = 60, + extra_headers: Optional[Dict[str, str]] = None, ): """Init the Python client. @@ -127,22 +128,32 @@ def __init__( workspace: The workspace to which records will be logged/loaded. If `None` (default) and the env variable ``RUBRIX_WORKSPACE`` is not set, it will default to the private user workspace. timeout: Wait `timeout` seconds for the connection to timeout. Default: 60. + extra_headers: Extra HTTP headers sent to the server. You can use this to customize + the headers of Rubrix client requests, like additional security restrictions. Default: `None`. Examples: >>> import rubrix as rb >>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y") + >>> # Customizing request headers + >>> headers = {"X-Client-id":"id","X-Secret":"secret"} + >>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y", extra_headers=headers) + """ api_url = api_url or os.getenv("RUBRIX_API_URL", "http://localhost:6900") # Checking that the api_url does not end in '/' api_url = re.sub(r"\/$", "", api_url) api_key = api_key or os.getenv("RUBRIX_API_KEY", DEFAULT_API_KEY) workspace = workspace or os.getenv("RUBRIX_WORKSPACE") + headers = extra_headers or {} self._client: AuthenticatedClient = AuthenticatedClient( - base_url=api_url, token=api_key, timeout=timeout + base_url=api_url, + token=api_key, + timeout=timeout, + headers=headers.copy(), ) - self._user: User = users_api.whoami(client=self._client) + self._user: User = users_api.whoami(client=self._client) if workspace is not None: self.set_workspace(workspace) diff --git a/tests/client/test_init.py b/tests/client/test_init.py index 73a4b02012..cb5b9a5dbc 100644 --- a/tests/client/test_init.py +++ b/tests/client/test_init.py @@ -15,8 +15,8 @@ from rubrix.client import api -def test_resource_leaking_with_several_inits(mocked_client): - dataset = "test_resource_leaking_with_several_inits" +def test_resource_leaking_with_several_init(mocked_client): + dataset = "test_resource_leaking_with_several_init" api.delete(dataset) # TODO: review performance in Windows. See https://github.com/recognai/rubrix/pull/1702 @@ -30,3 +30,17 @@ def test_resource_leaking_with_several_inits(mocked_client): ) assert len(api.load(dataset)) == 10 + + +def test_init_with_extra_headers(mocked_client): + expected_headers = { + "X-Custom-Header": "Mocking rules!", + "Other-header": "Header value", + } + api.init(extra_headers=expected_headers) + active_api = api.active_api() + + for key, value in expected_headers.items(): + assert ( + active_api.client.headers[key] == value + ), f"{key}:{value} not in client headers"