Skip to content

Commit

Permalink
fix(Client): reusing the inner httpx client (#1640)
Browse files Browse the repository at this point in the history
Closes #1646

(cherry picked from commit f9a80cc)
  • Loading branch information
frascuchon committed Aug 18, 2022
1 parent 660e575 commit ce6b8d8
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 149 deletions.
18 changes: 8 additions & 10 deletions src/rubrix/client/api.py
Expand Up @@ -20,7 +20,7 @@
from asyncio import Future
from functools import wraps
from inspect import signature
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from tqdm.auto import tqdm

Expand All @@ -47,7 +47,7 @@
TokenClassificationRecord,
)
from rubrix.client.sdk.client import AuthenticatedClient
from rubrix.client.sdk.commons.api import async_bulk, bulk
from rubrix.client.sdk.commons.api import async_bulk
from rubrix.client.sdk.commons.errors import RubrixClientError
from rubrix.client.sdk.datasets import api as datasets_api
from rubrix.client.sdk.datasets.models import CopyDatasetRequest, TaskType
Expand All @@ -73,7 +73,7 @@
TokenClassificationBulkData,
TokenClassificationQuery,
)
from rubrix.client.sdk.users.api import whoami
from rubrix.client.sdk.users import api as users_api
from rubrix.client.sdk.users.models import User
from rubrix.utils import setup_loop_in_thread

Expand Down Expand Up @@ -102,12 +102,6 @@ def log(self, *args, **kwargs) -> Future:
self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__
)

def __del__(self):
self.__loop__.stop()

del self.__loop__
del self.__thread__


class Api:
# Larger sizes will trigger a warning
Expand Down Expand Up @@ -147,13 +141,17 @@ def __init__(
self._client: AuthenticatedClient = AuthenticatedClient(
base_url=api_url, token=api_key, timeout=timeout
)
self._user: User = whoami(client=self._client)
self._user: User = users_api.whoami(client=self._client)

if workspace is not None:
self.set_workspace(workspace)

self._agent = _RubrixLogAgent(self)

def __del__(self):
del self._client
del self._agent

@property
def client(self):
"""The underlying authenticated client"""
Expand Down
55 changes: 33 additions & 22 deletions src/rubrix/client/sdk/client.py
Expand Up @@ -23,6 +23,7 @@

@dataclasses.dataclass
class _ClientCommonDefaults:
__httpx__: httpx.Client = dataclasses.field(default=None, init=False, compare=False)

cookies: Dict[str, str] = dataclasses.field(default_factory=dict)
headers: Dict[str, str] = dataclasses.field(default_factory=dict)
Expand All @@ -43,64 +44,74 @@ def get_timeout(self) -> float:
class _Client:
base_url: str


@dataclasses.dataclass
class _AuthenticatedClient(_Client):
token: str

def __post_init__(self):
self.base_url = self.base_url.strip()
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]


@dataclasses.dataclass
class _AuthenticatedClient(_Client):
token: str


@dataclasses.dataclass
class Client(_ClientCommonDefaults, _Client):
def __post_init__(self):
super().__post_init__()
self.__httpx__ = httpx.Client(
base_url=self.base_url,
headers=self.get_headers(),
cookies=self.get_cookies(),
timeout=self.get_timeout(),
)

def __del__(self):
del self.__httpx__

def __hash__(self):
return hash(self.base_url)

def get(self, path: str, *args, **kwargs):
path = self._normalize_path(path)
url = f"{self.base_url}/{path}"
response = httpx.get(
url=url,
response = self.__httpx__.get(
url=path,
headers=self.get_headers(),
cookies=self.get_cookies(),
timeout=self.get_timeout(),
*args,
**kwargs,
)

return build_raw_response(response).parsed

def post(self, path: str, *args, **kwargs):
path = self._normalize_path(path)
url = f"{self.base_url}/{path}"

response = httpx.post(
url=url,
response = self.__httpx__.post(
url=path,
headers=self.get_headers(),
cookies=self.get_cookies(),
timeout=self.get_timeout(),
*args,
**kwargs,
)
return build_raw_response(response).parsed

def put(self, path: str, *args, **kwargs):
path = self._normalize_path(path)
url = f"{self.base_url}/{path}"

response = httpx.put(
url=url,
response = self.__httpx__.put(
url=path,
headers=self.get_headers(),
cookies=self.get_cookies(),
timeout=self.get_timeout(),
*args,
**kwargs,
)
return build_raw_response(response).parsed

def stream(self, path: str, *args, **kwargs):
return self.__httpx__.stream(
url=path,
headers=self.get_headers(),
timeout=None, # Avoid timeouts. TODO: Improve the logic
*args,
**kwargs,
)

@staticmethod
def _normalize_path(path: str) -> str:
path = path.strip()
Expand Down
9 changes: 3 additions & 6 deletions src/rubrix/client/sdk/text2text/api.py
Expand Up @@ -28,14 +28,11 @@ def data(
request: Optional[Text2TextQuery] = None,
limit: Optional[int] = None,
) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]:
url = "{}/api/datasets/{name}/Text2Text/data".format(client.base_url, name=name)
path = f"/api/datasets/{name}/Text2Text/data"

with httpx.stream(
with client.stream(
method="POST",
url=url,
headers=client.get_headers(),
cookies=client.get_cookies(),
timeout=None,
path=path,
params={"limit": limit} if limit else None,
json=request.dict() if request else {},
) as response:
Expand Down
12 changes: 4 additions & 8 deletions src/rubrix/client/sdk/text_classification/api.py
Expand Up @@ -34,16 +34,12 @@ def data(
request: Optional[TextClassificationQuery] = None,
limit: Optional[int] = None,
) -> Response[Union[List[TextClassificationRecord], HTTPValidationError, ErrorMessage]]:
url = "{}/api/datasets/{name}/TextClassification/data".format(
client.base_url, name=name
)

with httpx.stream(
path = f"/api/datasets/{name}/TextClassification/data"

with client.stream(
method="POST",
url=url,
headers=client.get_headers(),
cookies=client.get_cookies(),
timeout=None,
path=path,
params={"limit": limit} if limit else None,
json=request.dict() if request else {},
) as response:
Expand Down
11 changes: 3 additions & 8 deletions src/rubrix/client/sdk/token_classification/api.py
Expand Up @@ -34,16 +34,11 @@ def data(
) -> Response[
Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]
]:
url = "{}/api/datasets/{name}/TokenClassification/data".format(
client.base_url, name=name
)
path = f"/api/datasets/{name}/TokenClassification/data"

with httpx.stream(
with client.stream(
path=path,
method="POST",
url=url,
headers=client.get_headers(),
cookies=client.get_cookies(),
timeout=None,
params={"limit": limit} if limit else None,
json=request.dict() if request else {},
) as response:
Expand Down
15 changes: 2 additions & 13 deletions src/rubrix/client/sdk/users/api.py
Expand Up @@ -6,16 +6,5 @@


def whoami(client: AuthenticatedClient) -> User:
url = "{}/api/me".format(client.base_url)

response = httpx.get(
url=url,
headers=client.get_headers(),
cookies=client.get_cookies(),
timeout=client.get_timeout(),
)

if response.status_code == 200:
return User(**response.json())

handle_response_error(response, msg="Invalid credentials")
response = client.get("/api/me")
return User(**response)
22 changes: 18 additions & 4 deletions src/rubrix/utils.py
Expand Up @@ -148,14 +148,28 @@ def limit_value_length(data: Any, max_length: int) -> Any:
return data


__LOOP__, __THREAD__ = None, None


def setup_loop_in_thread() -> Tuple[asyncio.AbstractEventLoop, threading.Thread]:
"""Sets up a new asyncio event loop in a new thread, and runs it forever.
Returns:
A tuple containing the event loop and the thread.
"""
loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()

return loop, thread
def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()

global __LOOP__
global __THREAD__

if not (__LOOP__ and __THREAD__):
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=start_background_loop, args=(loop,), daemon=True
)
thread.start()
__LOOP__, __THREAD__ = loop, thread
return __LOOP__, __THREAD__
10 changes: 7 additions & 3 deletions tests/client/sdk/conftest.py
Expand Up @@ -56,9 +56,13 @@ def helpers():
return Helpers()


@pytest.fixture(scope="session")
def sdk_client():
return AuthenticatedClient(base_url="http://localhost:6900", token=DEFAULT_API_KEY)
@pytest.fixture
def sdk_client(mocked_client, monkeypatch):
client = AuthenticatedClient(
base_url="http://localhost:6900", token=DEFAULT_API_KEY
)
monkeypatch.setattr(client, "__httpx__", mocked_client)
return client


@pytest.fixture
Expand Down
7 changes: 1 addition & 6 deletions tests/client/sdk/text2text/test_api.py
Expand Up @@ -20,12 +20,7 @@


@pytest.mark.parametrize("limit,expected", [(None, 3), (2, 2)])
def test_data(
limit, mocked_client, expected, sdk_client, bulk_text2text_data, monkeypatch
):
# TODO: Not sure how to test the streaming part of the response here
monkeypatch.setattr(httpx, "stream", mocked_client.stream)

def test_data(limit, mocked_client, expected, sdk_client, bulk_text2text_data):
dataset_name = "test_dataset"
mocked_client.delete(f"/api/datasets/{dataset_name}")
mocked_client.post(
Expand Down
3 changes: 0 additions & 3 deletions tests/client/sdk/text_classification/test_api.py
Expand Up @@ -23,9 +23,6 @@
def test_data(
mocked_client, limit, expected, bulk_textclass_data, sdk_client, monkeypatch
):
# TODO: Not sure how to test the streaming part of the response here
monkeypatch.setattr(httpx, "stream", mocked_client.stream)

dataset_name = "test_dataset"
mocked_client.delete(f"/api/datasets/{dataset_name}")
mocked_client.post(
Expand Down
2 changes: 0 additions & 2 deletions tests/client/sdk/token_classification/test_api.py
Expand Up @@ -23,8 +23,6 @@
def test_data(
mocked_client, limit, expected, sdk_client, bulk_tokenclass_data, monkeypatch
):
# TODO: Not sure how to test the streaming part of the response here
monkeypatch.setattr(httpx, "stream", mocked_client.stream)

dataset_name = "test_dataset"
mocked_client.delete(f"/api/datasets/{dataset_name}")
Expand Down
13 changes: 6 additions & 7 deletions tests/client/sdk/users/test_api.py
Expand Up @@ -8,19 +8,18 @@
from rubrix.client.sdk.users.models import User


def test_whoami(mocked_client):
sdk_client = AuthenticatedClient(
base_url="http://localhost:6900", token=DEFAULT_API_KEY
)
def test_whoami(mocked_client, sdk_client):
user = whoami(client=sdk_client)
assert isinstance(user, User)


def test_whoami_with_auth_error(mocked_client):
def test_whoami_with_auth_error(monkeypatch, mocked_client):
with pytest.raises(UnauthorizedApiError):
whoami(
AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey")
sdk_client = AuthenticatedClient(
base_url="http://localhost:6900", token="wrong-apikey"
)
monkeypatch.setattr(sdk_client, "__httpx__", mocked_client)
whoami(sdk_client)


def test_whoami_with_connection_error():
Expand Down

0 comments on commit ce6b8d8

Please sign in to comment.