From ce6b8d8e29fbeedd6b4700f28af2095b385950eb Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 2 Aug 2022 15:39:17 +0200 Subject: [PATCH] fix(Client): reusing the inner `httpx` client (#1640) Closes #1646 (cherry picked from commit f9a80cc363b91bb29194f50e3e5129b1e0e36c27) --- src/rubrix/client/api.py | 18 +++--- src/rubrix/client/sdk/client.py | 55 +++++++++++-------- src/rubrix/client/sdk/text2text/api.py | 9 +-- .../client/sdk/text_classification/api.py | 12 ++-- .../client/sdk/token_classification/api.py | 11 +--- src/rubrix/client/sdk/users/api.py | 15 +---- src/rubrix/utils.py | 22 ++++++-- tests/client/sdk/conftest.py | 10 +++- tests/client/sdk/text2text/test_api.py | 7 +-- .../sdk/text_classification/test_api.py | 3 - .../sdk/token_classification/test_api.py | 2 - tests/client/sdk/users/test_api.py | 13 ++--- tests/client/test_api.py | 50 ++++------------- tests/client/test_init.py | 17 ++++++ tests/conftest.py | 46 ++++++++++++---- tests/helpers.py | 15 ++--- 16 files changed, 156 insertions(+), 149 deletions(-) create mode 100644 tests/client/test_init.py diff --git a/src/rubrix/client/api.py b/src/rubrix/client/api.py index c3e95f34db..5b2bdcb0ef 100644 --- a/src/rubrix/client/api.py +++ b/src/rubrix/client/api.py @@ -20,7 +20,7 @@ from asyncio import Future from functools import wraps from inspect import signature -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union from tqdm.auto import tqdm @@ -47,7 +47,7 @@ TokenClassificationRecord, ) from rubrix.client.sdk.client import AuthenticatedClient -from rubrix.client.sdk.commons.api import async_bulk, bulk +from rubrix.client.sdk.commons.api import async_bulk from rubrix.client.sdk.commons.errors import RubrixClientError from rubrix.client.sdk.datasets import api as datasets_api from rubrix.client.sdk.datasets.models import CopyDatasetRequest, TaskType @@ -73,7 +73,7 @@ TokenClassificationBulkData, TokenClassificationQuery, ) -from rubrix.client.sdk.users.api import whoami +from rubrix.client.sdk.users import api as users_api from rubrix.client.sdk.users.models import User from rubrix.utils import setup_loop_in_thread @@ -102,12 +102,6 @@ def log(self, *args, **kwargs) -> Future: self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__ ) - def __del__(self): - self.__loop__.stop() - - del self.__loop__ - del self.__thread__ - class Api: # Larger sizes will trigger a warning @@ -147,13 +141,17 @@ def __init__( self._client: AuthenticatedClient = AuthenticatedClient( base_url=api_url, token=api_key, timeout=timeout ) - self._user: User = whoami(client=self._client) + self._user: User = users_api.whoami(client=self._client) if workspace is not None: self.set_workspace(workspace) self._agent = _RubrixLogAgent(self) + def __del__(self): + del self._client + del self._agent + @property def client(self): """The underlying authenticated client""" diff --git a/src/rubrix/client/sdk/client.py b/src/rubrix/client/sdk/client.py index 487d65c686..6d529053eb 100644 --- a/src/rubrix/client/sdk/client.py +++ b/src/rubrix/client/sdk/client.py @@ -23,6 +23,7 @@ @dataclasses.dataclass class _ClientCommonDefaults: + __httpx__: httpx.Client = dataclasses.field(default=None, init=False, compare=False) cookies: Dict[str, str] = dataclasses.field(default_factory=dict) headers: Dict[str, str] = dataclasses.field(default_factory=dict) @@ -43,45 +44,50 @@ def get_timeout(self) -> float: class _Client: base_url: str - -@dataclasses.dataclass -class _AuthenticatedClient(_Client): - token: str - def __post_init__(self): self.base_url = self.base_url.strip() if self.base_url.endswith("/"): self.base_url = self.base_url[:-1] +@dataclasses.dataclass +class _AuthenticatedClient(_Client): + token: str + + @dataclasses.dataclass class Client(_ClientCommonDefaults, _Client): + def __post_init__(self): + super().__post_init__() + self.__httpx__ = httpx.Client( + base_url=self.base_url, + headers=self.get_headers(), + cookies=self.get_cookies(), + timeout=self.get_timeout(), + ) + + def __del__(self): + del self.__httpx__ + def __hash__(self): return hash(self.base_url) def get(self, path: str, *args, **kwargs): path = self._normalize_path(path) - url = f"{self.base_url}/{path}" - response = httpx.get( - url=url, + response = self.__httpx__.get( + url=path, headers=self.get_headers(), - cookies=self.get_cookies(), - timeout=self.get_timeout(), *args, **kwargs, ) - return build_raw_response(response).parsed def post(self, path: str, *args, **kwargs): path = self._normalize_path(path) - url = f"{self.base_url}/{path}" - response = httpx.post( - url=url, + response = self.__httpx__.post( + url=path, headers=self.get_headers(), - cookies=self.get_cookies(), - timeout=self.get_timeout(), *args, **kwargs, ) @@ -89,18 +95,23 @@ def post(self, path: str, *args, **kwargs): def put(self, path: str, *args, **kwargs): path = self._normalize_path(path) - url = f"{self.base_url}/{path}" - - response = httpx.put( - url=url, + response = self.__httpx__.put( + url=path, headers=self.get_headers(), - cookies=self.get_cookies(), - timeout=self.get_timeout(), *args, **kwargs, ) return build_raw_response(response).parsed + def stream(self, path: str, *args, **kwargs): + return self.__httpx__.stream( + url=path, + headers=self.get_headers(), + timeout=None, # Avoid timeouts. TODO: Improve the logic + *args, + **kwargs, + ) + @staticmethod def _normalize_path(path: str) -> str: path = path.strip() diff --git a/src/rubrix/client/sdk/text2text/api.py b/src/rubrix/client/sdk/text2text/api.py index b013229844..29ecd114f3 100644 --- a/src/rubrix/client/sdk/text2text/api.py +++ b/src/rubrix/client/sdk/text2text/api.py @@ -28,14 +28,11 @@ def data( request: Optional[Text2TextQuery] = None, limit: Optional[int] = None, ) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]: - url = "{}/api/datasets/{name}/Text2Text/data".format(client.base_url, name=name) + path = f"/api/datasets/{name}/Text2Text/data" - with httpx.stream( + with client.stream( method="POST", - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=None, + path=path, params={"limit": limit} if limit else None, json=request.dict() if request else {}, ) as response: diff --git a/src/rubrix/client/sdk/text_classification/api.py b/src/rubrix/client/sdk/text_classification/api.py index c9d87f853d..62974f7721 100644 --- a/src/rubrix/client/sdk/text_classification/api.py +++ b/src/rubrix/client/sdk/text_classification/api.py @@ -34,16 +34,12 @@ def data( request: Optional[TextClassificationQuery] = None, limit: Optional[int] = None, ) -> Response[Union[List[TextClassificationRecord], HTTPValidationError, ErrorMessage]]: - url = "{}/api/datasets/{name}/TextClassification/data".format( - client.base_url, name=name - ) - with httpx.stream( + path = f"/api/datasets/{name}/TextClassification/data" + + with client.stream( method="POST", - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=None, + path=path, params={"limit": limit} if limit else None, json=request.dict() if request else {}, ) as response: diff --git a/src/rubrix/client/sdk/token_classification/api.py b/src/rubrix/client/sdk/token_classification/api.py index 093d996caf..4af9c2ac2b 100644 --- a/src/rubrix/client/sdk/token_classification/api.py +++ b/src/rubrix/client/sdk/token_classification/api.py @@ -34,16 +34,11 @@ def data( ) -> Response[ Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage] ]: - url = "{}/api/datasets/{name}/TokenClassification/data".format( - client.base_url, name=name - ) + path = f"/api/datasets/{name}/TokenClassification/data" - with httpx.stream( + with client.stream( + path=path, method="POST", - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=None, params={"limit": limit} if limit else None, json=request.dict() if request else {}, ) as response: diff --git a/src/rubrix/client/sdk/users/api.py b/src/rubrix/client/sdk/users/api.py index 21ae70adbb..27168ad62a 100644 --- a/src/rubrix/client/sdk/users/api.py +++ b/src/rubrix/client/sdk/users/api.py @@ -6,16 +6,5 @@ def whoami(client: AuthenticatedClient) -> User: - url = "{}/api/me".format(client.base_url) - - response = httpx.get( - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=client.get_timeout(), - ) - - if response.status_code == 200: - return User(**response.json()) - - handle_response_error(response, msg="Invalid credentials") + response = client.get("/api/me") + return User(**response) diff --git a/src/rubrix/utils.py b/src/rubrix/utils.py index 500004d669..bf8e47c3a2 100644 --- a/src/rubrix/utils.py +++ b/src/rubrix/utils.py @@ -148,14 +148,28 @@ def limit_value_length(data: Any, max_length: int) -> Any: return data +__LOOP__, __THREAD__ = None, None + + def setup_loop_in_thread() -> Tuple[asyncio.AbstractEventLoop, threading.Thread]: """Sets up a new asyncio event loop in a new thread, and runs it forever. Returns: A tuple containing the event loop and the thread. """ - loop = asyncio.new_event_loop() - thread = threading.Thread(target=loop.run_forever, daemon=True) - thread.start() - return loop, thread + def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + global __LOOP__ + global __THREAD__ + + if not (__LOOP__ and __THREAD__): + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=start_background_loop, args=(loop,), daemon=True + ) + thread.start() + __LOOP__, __THREAD__ = loop, thread + return __LOOP__, __THREAD__ diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py index e7e8627368..a62e324337 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -56,9 +56,13 @@ def helpers(): return Helpers() -@pytest.fixture(scope="session") -def sdk_client(): - return AuthenticatedClient(base_url="http://localhost:6900", token=DEFAULT_API_KEY) +@pytest.fixture +def sdk_client(mocked_client, monkeypatch): + client = AuthenticatedClient( + base_url="http://localhost:6900", token=DEFAULT_API_KEY + ) + monkeypatch.setattr(client, "__httpx__", mocked_client) + return client @pytest.fixture diff --git a/tests/client/sdk/text2text/test_api.py b/tests/client/sdk/text2text/test_api.py index 402e6d8ef5..ccbda4314c 100644 --- a/tests/client/sdk/text2text/test_api.py +++ b/tests/client/sdk/text2text/test_api.py @@ -20,12 +20,7 @@ @pytest.mark.parametrize("limit,expected", [(None, 3), (2, 2)]) -def test_data( - limit, mocked_client, expected, sdk_client, bulk_text2text_data, monkeypatch -): - # TODO: Not sure how to test the streaming part of the response here - monkeypatch.setattr(httpx, "stream", mocked_client.stream) - +def test_data(limit, mocked_client, expected, sdk_client, bulk_text2text_data): dataset_name = "test_dataset" mocked_client.delete(f"/api/datasets/{dataset_name}") mocked_client.post( diff --git a/tests/client/sdk/text_classification/test_api.py b/tests/client/sdk/text_classification/test_api.py index 9d8e8bb6ac..ff0e88ae1b 100644 --- a/tests/client/sdk/text_classification/test_api.py +++ b/tests/client/sdk/text_classification/test_api.py @@ -23,9 +23,6 @@ def test_data( mocked_client, limit, expected, bulk_textclass_data, sdk_client, monkeypatch ): - # TODO: Not sure how to test the streaming part of the response here - monkeypatch.setattr(httpx, "stream", mocked_client.stream) - dataset_name = "test_dataset" mocked_client.delete(f"/api/datasets/{dataset_name}") mocked_client.post( diff --git a/tests/client/sdk/token_classification/test_api.py b/tests/client/sdk/token_classification/test_api.py index e88aa46acb..42e1d2316a 100644 --- a/tests/client/sdk/token_classification/test_api.py +++ b/tests/client/sdk/token_classification/test_api.py @@ -23,8 +23,6 @@ def test_data( mocked_client, limit, expected, sdk_client, bulk_tokenclass_data, monkeypatch ): - # TODO: Not sure how to test the streaming part of the response here - monkeypatch.setattr(httpx, "stream", mocked_client.stream) dataset_name = "test_dataset" mocked_client.delete(f"/api/datasets/{dataset_name}") diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py index eabc332dfc..2c77c92fee 100644 --- a/tests/client/sdk/users/test_api.py +++ b/tests/client/sdk/users/test_api.py @@ -8,19 +8,18 @@ from rubrix.client.sdk.users.models import User -def test_whoami(mocked_client): - sdk_client = AuthenticatedClient( - base_url="http://localhost:6900", token=DEFAULT_API_KEY - ) +def test_whoami(mocked_client, sdk_client): user = whoami(client=sdk_client) assert isinstance(user, User) -def test_whoami_with_auth_error(mocked_client): +def test_whoami_with_auth_error(monkeypatch, mocked_client): with pytest.raises(UnauthorizedApiError): - whoami( - AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey") + sdk_client = AuthenticatedClient( + base_url="http://localhost:6900", token="wrong-apikey" ) + monkeypatch.setattr(sdk_client, "__httpx__", mocked_client) + whoami(sdk_client) def test_whoami_with_connection_error(): diff --git a/tests/client/test_api.py b/tests/client/test_api.py index dc453cf8ce..a81b69e1f5 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -18,7 +18,6 @@ import datasets import httpx -import pandas import pandas as pd import pytest @@ -34,10 +33,11 @@ UnauthorizedApiError, ValidationApiError, ) +from rubrix.client.sdk.users import api as users_api +from rubrix.client.sdk.users.models import User from rubrix.server.apis.v0.models.text_classification import ( TextClassificationSearchResults, ) -from rubrix.server.security import auth from tests.server.test_api import create_some_data_for_text_classification @@ -48,12 +48,10 @@ def mock_response_200(monkeypatch): It will return a 200 status code, emulating the correct login. """ - def mock_get(url, *args, **kwargs): - if "/api/me" in url: - return httpx.Response(status_code=200, json={"username": "booohh"}) - return httpx.Response(status_code=200) + def mock_get(*args, **kwargs): + return User(username="booohh") - monkeypatch.setattr(httpx, "get", mock_get) + monkeypatch.setattr(users_api, "whoami", mock_get) @pytest.fixture @@ -64,9 +62,9 @@ def mock_response_500(monkeypatch): """ def mock_get(*args, **kwargs): - return httpx.Response(status_code=500) + raise GenericApiError("Mock error") - monkeypatch.setattr(httpx, "get", mock_get) + monkeypatch.setattr(users_api, "whoami", mock_get) @pytest.fixture @@ -76,16 +74,14 @@ def mock_response_token_401(monkeypatch): It will return a 401 status code, emulating an invalid credentials error when using tokens to log in. Iterable structure to be able to pass the first 200 status code check """ - response_200 = httpx.Response(status_code=200) - response_401 = httpx.Response(status_code=401) def mock_get(*args, **kwargs): if kwargs["url"] == "fake_url/api/me": - return response_401 + raise UnauthorizedApiError() elif kwargs["url"] == "fake_url/api/docs/spec.json": - return response_200 + return User(username="booohh") - monkeypatch.setattr(httpx, "get", mock_get) + monkeypatch.setattr(users_api, "whoami", mock_get) def test_init_correct(mock_response_200): @@ -94,10 +90,10 @@ def test_init_correct(mock_response_200): It checks if the _client created is a RubrixClient object. """ - api.init() - assert api.__ACTIVE_API__._client == AuthenticatedClient( + assert api.active_api()._client == AuthenticatedClient( base_url="http://localhost:6900", token="rubrix.apikey", timeout=60.0 ) + assert api.__ACTIVE_API__._user == api.User(username="booohh") api.init(api_url="mock_url", api_key="mock_key", workspace="mock_ws", timeout=42) @@ -109,28 +105,6 @@ def test_init_correct(mock_response_200): ) -def test_init_incorrect(mock_response_500): - """Testing incorrect default initalization - - It checks an Exception is raised with the correct message. - """ - - with pytest.raises( - Exception, - match="Rubrix server returned an error with http status: 500\nError details: \[\{'response': None\}\]", - ): - api.init() - - -def test_init_token_auth_fail(mock_response_token_401): - """Testing initalization with failed authentication - - It checks an Exception is raised with the correct message. - """ - with pytest.raises(UnauthorizedApiError): - api.init(api_url="fake_url", api_key="422") - - def test_init_evironment_url(mock_response_200, monkeypatch): """Testing initalization with api_url provided via environment variable diff --git a/tests/client/test_init.py b/tests/client/test_init.py new file mode 100644 index 0000000000..7077e37a1d --- /dev/null +++ b/tests/client/test_init.py @@ -0,0 +1,17 @@ +from rubrix.client import api + + +def test_resource_leaking_with_several_inits(mocked_client): + dataset = "test_resource_leaking_with_several_inits" + api.delete(dataset) + + for i in range(0, 1000): + api.init() + + for i in range(0, 10): + api.init() + api.log( + api.TextClassificationRecord(text="The text"), name=dataset, verbose=False + ) + + assert len(api.load(dataset)) == 10 diff --git a/tests/conftest.py b/tests/conftest.py index e936cee6c3..e707b1d7b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,30 +1,52 @@ import httpx import pytest from _pytest.logging import LogCaptureFixture -from loguru import logger + +from rubrix.client.sdk.users import api as users_api + +try: + from loguru import logger +except ModuleNotFoundError: + logger = None from starlette.testclient import TestClient from rubrix import app +from rubrix.client.api import active_api from tests.helpers import SecuredClient @pytest.fixture def mocked_client(monkeypatch) -> SecuredClient: + with TestClient(app, raise_server_exceptions=False) as _client: - client = SecuredClient(_client) + client_ = SecuredClient(_client) + + real_whoami = users_api.whoami + + def whoami_mocked(client): + monkeypatch.setattr(client, "__httpx__", client_) + return real_whoami(client) + + monkeypatch.setattr(users_api, "whoami", whoami_mocked) + + monkeypatch.setattr(httpx, "post", client_.post) + monkeypatch.setattr(httpx.AsyncClient, "post", client_.post_async) + monkeypatch.setattr(httpx, "get", client_.get) + monkeypatch.setattr(httpx, "delete", client_.delete) + monkeypatch.setattr(httpx, "put", client_.put) + monkeypatch.setattr(httpx, "stream", client_.stream) - monkeypatch.setattr(httpx, "post", client.post) - monkeypatch.setattr(httpx.AsyncClient, "post", client.post_async) - monkeypatch.setattr(httpx, "get", client.get) - monkeypatch.setattr(httpx, "delete", client.delete) - monkeypatch.setattr(httpx, "put", client.put) - monkeypatch.setattr(httpx, "stream", client.stream) + rb_api = active_api() + monkeypatch.setattr(rb_api._client, "__httpx__", client_) - yield client + yield client_ @pytest.fixture def caplog(caplog: LogCaptureFixture): - handler_id = logger.add(caplog.handler, format="{message}") - yield caplog - logger.remove(handler_id) + if not logger: + yield caplog + else: + handler_id = logger.add(caplog.handler, format="{message}") + yield caplog + logger.remove(handler_id) diff --git a/tests/helpers.py b/tests/helpers.py index 0e7f53e0dc..f2ec048f48 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -3,8 +3,8 @@ from fastapi import FastAPI from starlette.testclient import TestClient -import rubrix from rubrix._constants import API_KEY_HEADER_NAME +from rubrix.client.api import active_api from rubrix.server.security import auth from rubrix.server.security.auth_provider.local.settings import settings @@ -20,17 +20,18 @@ def fastpi_app(self) -> FastAPI: def add_workspaces_to_rubrix_user(self, workspaces: List[str]): rubrix_user = auth.users.__dao__.__users__["rubrix"] - workspaces = workspaces or [] - workspaces.extend(rubrix_user.workspaces or []) - rubrix_user.workspaces = workspaces + rubrix_user.workspaces.extend(workspaces or []) - rubrix.init() + rb_api = active_api() + rb_api._user = rubrix_user def reset_rubrix_workspaces(self): rubrix_user = auth.users.__dao__.__users__["rubrix"] - rubrix_user.workspaces = None + rubrix_user.workspaces = ["", "rubrix"] - rubrix.init() + rb_api = active_api() + rb_api._user = rubrix_user + rb_api.set_workspace("rubrix") def delete(self, *args, **kwargs): request_headers = kwargs.pop("headers", {})