# Patch summary (text before the first "diff --git" header is not part of any hunk).
#
# ai21/http_client.py
#   _init_client: when no client is supplied and the retry policy is enabled, the
#   value returned by _requests_retry_session(retries=...) is now passed as the
#   `transport` of a new httpx.Client instead of being returned directly, so both
#   branches return an httpx.Client.
#
# tests/integration_tests/clients/studio/test_chat_completions.py
#   Adds test_chat_completion_when_num_retries_is_over_1__should_retry, which
#   patches httpx.Client.send to yield two status-500 responses followed by one
#   status-200 response and asserts that send was called num_retries (3) times.
#
# NOTE(review): the line breaks and indentation below were restored from the hunk
# line counts and Python syntax; whether the two final asserts sit inside the
# `with` block is an assumption - confirm against the original commit.
diff --git a/ai21/http_client.py b/ai21/http_client.py
index 7bb1343f..834869fa 100644
--- a/ai21/http_client.py
+++ b/ai21/http_client.py
@@ -149,7 +149,10 @@ def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client:
         if client is not None:
             return client
 
-        return _requests_retry_session(retries=self._num_retries) if self._apply_retry_policy else httpx.Client()
+        if self._apply_retry_policy:
+            return httpx.Client(transport=_requests_retry_session(retries=self._num_retries))
+
+        return httpx.Client()
 
     def add_headers(self, headers: Dict[str, Any]) -> None:
         self._headers.update(headers)
diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py
index d564aeac..6d90317d 100644
--- a/tests/integration_tests/clients/studio/test_chat_completions.py
+++ b/tests/integration_tests/clients/studio/test_chat_completions.py
@@ -1,3 +1,8 @@
+import json
+from unittest.mock import patch
+
+import httpx
+
 from ai21 import AI21Client
 from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk, ChoicesChunk, ChoiceDelta
 from ai21.models import RoleType
@@ -10,6 +15,16 @@
     ),
 ]
 
+_BAD_HTTPX_REQUEST = httpx.Request(method="POST", url="http://test_url")
+_BAD_HTTPX_RESPONSE = httpx.Response(status_code=500, request=_BAD_HTTPX_REQUEST)
+_EXPECTED_MESSAGE_CONTENT = {
+    "id": "test",
+    "choices": [
+        {"index": 0, "message": {"content": "test", "role": "assistant"}, "finish_reason": None, "logprobs": None}
+    ],
+    "usage": {"prompt_tokens": 1, "total_tokens": 2, "completion_tokens": 1},
+}
+
 
 def test_chat_completion():
     messages = _MESSAGES
@@ -29,6 +44,30 @@ def test_chat_completion():
     assert response.choices[0].message.role
 
 
+def test_chat_completion_when_num_retries_is_over_1__should_retry():
+    num_retries = 3
+
+    with patch.object(
+        httpx.Client,
+        "send",
+        side_effect=[
+            _BAD_HTTPX_RESPONSE,
+            _BAD_HTTPX_RESPONSE,
+            httpx.Response(status_code=200, content=json.dumps(_EXPECTED_MESSAGE_CONTENT)),
+        ],
+    ) as mock_send:
+        messages = _MESSAGES
+
+        client = AI21Client(num_retries=num_retries)
+        response = client.chat.completions.create(
+            model=_MODEL,
+            messages=messages,
+        )
+
+        assert isinstance(response, ChatCompletionResponse)
+        assert mock_send.call_count == num_retries
+
+
 def test_chat_completion__with_n_param__should_return_n_choices():
     messages = _MESSAGES
     n = 3