From 1e1ca5bde1dca297ac8384d9ad50591b1e46093f Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Mon, 13 May 2024 10:52:39 +0300 Subject: [PATCH 01/14] feat: Add httpx support (#111) * feat: Added httpx to pyproject * feat: httpx instead of requests * feat: Removed requests * fix: not given * fix: setup * feat: Added tenacity for retry * fix: conftest * test: Added tests * fix: Rename * fix: Modified test * fix: CR * fix: request --- .../studio/resources/studio_library.py | 77 +++++----- ai21/http_client.py | 133 ++++++++++-------- poetry.lock | 108 +++++++++++++- pyproject.toml | 3 +- setup.py | 2 +- .../clients/studio/conftest.py | 2 +- tests/unittests/conftest.py | 6 +- tests/unittests/test_ai21_http_client.py | 18 +-- tests/unittests/test_http_client.py | 39 +++++ 9 files changed, 275 insertions(+), 113 deletions(-) create mode 100644 tests/unittests/test_http_client.py diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 3288c7c7..8b6e6622 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import Optional, List from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse +from ai21.types import NotGiven, NOT_GIVEN +from ai21.utils.typing import remove_not_given class StudioLibrary(StudioResource): @@ -22,14 +25,14 @@ def create( self, file_path: str, *, - path: Optional[str] = None, - labels: Optional[List[str]] = None, - public_url: Optional[str] = None, + path: Optional[str] | NotGiven = NOT_GIVEN, + labels: Optional[List[str]] | NotGiven = NOT_GIVEN, + public_url: Optional[str] | NotGiven = NOT_GIVEN, **kwargs, ) -> str: url = 
f"{self._client.get_base_url()}/{self._module_name}" files = {"file": open(file_path, "rb")} - body = {"path": path, "labels": labels, "publicUrl": public_url, **kwargs} + body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs}) raw_response = self._post(url=url, files=files, body=body) @@ -44,12 +47,12 @@ def get(self, file_id: str) -> FileResponse: def list( self, *, - offset: Optional[int] = None, - limit: Optional[int] = None, + offset: Optional[int] | NotGiven = NOT_GIVEN, + limit: Optional[int] | NotGiven = NOT_GIVEN, **kwargs, ) -> List[FileResponse]: url = f"{self._client.get_base_url()}/{self._module_name}" - params = {"offset": offset, "limit": limit} + params = remove_not_given({"offset": offset, "limit": limit}) raw_response = self._get(url=url, params=params) return [FileResponse.from_dict(file) for file in raw_response] @@ -58,16 +61,18 @@ def update( self, file_id: str, *, - public_url: Optional[str] = None, - labels: Optional[List[str]] = None, + public_url: Optional[str] | NotGiven = NOT_GIVEN, + labels: Optional[List[str]] | NotGiven = NOT_GIVEN, **kwargs, ) -> None: url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" - body = { - "publicUrl": public_url, - "labels": labels, - **kwargs, - } + body = remove_not_given( + { + "publicUrl": public_url, + "labels": labels, + **kwargs, + } + ) self._put(url=url, body=body) def delete(self, file_id: str) -> None: @@ -82,19 +87,21 @@ def create( self, query: str, *, - path: Optional[str] = None, - field_ids: Optional[List[str]] = None, - max_segments: Optional[int] = None, + path: Optional[str] | NotGiven = NOT_GIVEN, + field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, + max_segments: Optional[int] | NotGiven = NOT_GIVEN, **kwargs, ) -> LibrarySearchResponse: url = f"{self._client.get_base_url()}/{self._module_name}" - body = { - "query": query, - "path": path, - "fieldIds": field_ids, - "maxSegments": max_segments, - **kwargs, - } + body = 
remove_not_given( + { + "query": query, + "path": path, + "fieldIds": field_ids, + "maxSegments": max_segments, + **kwargs, + } + ) raw_response = self._post(url=url, body=body) return LibrarySearchResponse.from_dict(raw_response) @@ -106,18 +113,20 @@ def create( self, question: str, *, - path: Optional[str] = None, - field_ids: Optional[List[str]] = None, - labels: Optional[List[str]] = None, + path: Optional[str] | NotGiven = NOT_GIVEN, + field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, + labels: Optional[List[str]] | NotGiven = NOT_GIVEN, **kwargs, ) -> LibraryAnswerResponse: url = f"{self._client.get_base_url()}/{self._module_name}" - body = { - "question": question, - "path": path, - "fieldIds": field_ids, - "labels": labels, - **kwargs, - } + body = remove_not_given( + { + "question": question, + "path": path, + "fieldIds": field_ids, + "labels": labels, + **kwargs, + } + ) raw_response = self._post(url=url, body=body) return LibraryAnswerResponse.from_dict(raw_response) diff --git a/ai21/http_client.py b/ai21/http_client.py index 229c5156..e199cb1b 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,8 +1,9 @@ import json from typing import Optional, Dict, Any, BinaryIO -import requests -from requests.adapters import HTTPAdapter, Retry, RetryError +import httpx +from httpx import ConnectError +from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential, RetryError from ai21.errors import ( BadRequest, @@ -17,8 +18,8 @@ DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 +TIME_BETWEEN_RETRIES = 1 RETRY_BACK_OFF_FACTOR = 0.5 -TIME_BETWEEN_RETRIES = 1000 RETRY_ERROR_CODES = (408, 429, 500, 503) RETRY_METHOD_WHITELIST = ["GET", "POST", "PUT"] @@ -39,25 +40,16 @@ def handle_non_success_response(status_code: int, response_text: str): raise AI21APIError(status_code, details=response_text) -def requests_retry_session(session, retries=0): - retry = Retry( - total=retries, - read=retries, - connect=retries, - 
backoff_factor=RETRY_BACK_OFF_FACTOR, - status_forcelist=RETRY_ERROR_CODES, - allowed_methods=frozenset(RETRY_METHOD_WHITELIST), +def _requests_retry_session(retries: int) -> httpx.HTTPTransport: + return httpx.HTTPTransport( + retries=retries, ) - adapter = HTTPAdapter(max_retries=retry) - session.mount("https://", adapter) - session.mount("http://", adapter) - return session class HttpClient: def __init__( self, - session: Optional[requests.Session] = None, + client: Optional[httpx.Client] = None, timeout_sec: int = None, num_retries: int = None, headers: Dict = None, @@ -66,7 +58,18 @@ def __init__( self._num_retries = num_retries or DEFAULT_NUM_RETRIES self._headers = headers or {} self._apply_retry_policy = self._num_retries > 0 - self._session = self._init_session(session) + self._client = self._init_client(client) + + # Since we can't use the retry decorator on a method of a class as we can't access class attributes, + # we have to wrap the method in a function + self._request = retry( + wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES), + retry=retry_if_result(self._should_retry), + stop=stop_after_attempt(self._num_retries), + )(self._request) + + def _should_retry(self, response: httpx.Response) -> bool: + return response.status_code in RETRY_ERROR_CODES and response.request.method in RETRY_METHOD_WHITELIST def execute_http_request( self, @@ -75,66 +78,72 @@ def execute_http_request( params: Optional[Dict] = None, files: Optional[Dict[str, BinaryIO]] = None, ): - timeout = self._timeout_sec - headers = self._headers - data = json.dumps(params).encode() - logger.debug(f"Calling {method} {url} {headers} {data}") - try: - if method == "GET": - response = self._session.request( - method=method, - url=url, - headers=headers, - timeout=timeout, - params=params, - ) - elif files is not None: - if method != "POST": - raise ValueError( - f"execute_http_request supports only POST for files upload, but {method} was supplied instead" - 
) - if "Content-Type" in headers: - headers.pop( - "Content-Type" - ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = self._session.request( - method=method, - url=url, - headers=headers, - data=params, - files=files, - timeout=timeout, - ) + response = self._request(files=files, method=method, params=params, url=url) + except RetryError as retry_error: + last_attempt = retry_error.last_attempt + + if last_attempt.failed: + raise last_attempt.exception() else: - response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) - except ConnectionError as connection_error: + response = last_attempt.result() + + except ConnectError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error - except RetryError as retry_error: - logger.error( - f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}" - ) - raise retry_error except Exception as exception: logger.error(f"Calling {method} {url} failed with Exception: {exception}") raise exception - if response.status_code != 200: + if response.status_code != httpx.codes.OK: logger.error(f"Calling {method} {url} failed with a non-200 response code: {response.status_code}") handle_non_success_response(response.status_code, response.text) return response.json() - def _init_session(self, session: Optional[requests.Session]) -> requests.Session: - if session is not None: - return session + def _request( + self, files: Optional[Dict[str, BinaryIO]], method: str, params: Optional[Dict], url: str + ) -> httpx.Response: + timeout = self._timeout_sec + headers = self._headers + logger.debug(f"Calling {method} {url} {headers} {params}") + + if method == "GET": + return self._client.request( + method=method, + url=url, + headers=headers, + timeout=timeout, + params=params, + ) - return ( - requests_retry_session(requests.Session(), 
retries=self._num_retries) - if self._apply_retry_policy - else requests.Session() + if files is not None: + if method != "POST": + raise ValueError( + f"execute_http_request supports only POST for files upload, but {method} was supplied instead" + ) + if "Content-Type" in headers: + headers.pop( + "Content-Type" + ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload + data = params + else: + data = json.dumps(params).encode() if params else None + + return self._client.request( + method=method, + url=url, + headers=headers, + data=data, + timeout=timeout, + files=files, ) + def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client: + if client is not None: + return client + + return _requests_retry_session(retries=self._num_retries) if self._apply_retry_policy else httpx.Client() + def add_headers(self, headers: Dict[str, Any]) -> None: self._headers.update(headers) diff --git a/poetry.lock b/poetry.lock index 5b4e0c38..9409cfcc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "ai21-tokenizer" @@ -29,6 +29,28 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} +[[package]] +name = "anyio" +version = "4.3.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + [[package]] name = "black" version = "24.3.0" @@ -392,6 +414,62 @@ gitdb = ">=4.0.1,<5" [package.extras] test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + +[[package]] +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "huggingface-hub" version = "0.22.2" @@ -1375,6 +1453,32 @@ files = [ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + +[[package]] +name = "tenacity" +version = "8.3.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = 
"tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"}, + {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tokenizers" version = "0.15.2" @@ -1624,4 +1728,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "915b491d362897a01bb63a9fcc57b19915319a8f6d24e9157444520bed4c5b87" +content-hash = "c6c474d713d3660255aade619131c42d2c57ba74d992b7e20be02275601a5b48" diff --git a/pyproject.toml b/pyproject.toml index 9546e1e8..473cdac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,11 +56,12 @@ packages = [ [tool.poetry.dependencies] python = "^3.8" -requests = "^2.31.0" ai21-tokenizer = "^0.9.0" boto3 = { version = "^1.28.82", optional = true } dataclasses-json = "^0.6.3" typing-extensions = "^4.9.0" +httpx = "^0.27.0" +tenacity = "^8.3.0" [tool.poetry.group.dev.dependencies] diff --git a/setup.py b/setup.py index 2ba17979..e65e5028 100755 --- a/setup.py +++ b/setup.py @@ -21,6 +21,6 @@ packages=find_packages(exclude=["tests", "tests.*"]), keywords=["python", "sdk", "ai", "ai21", "jurassic", "ai21-python", "llm"], install_requires=[ - "requests", + "httpx", ], ) diff --git a/tests/integration_tests/clients/studio/conftest.py b/tests/integration_tests/clients/studio/conftest.py index 19da28b3..c1ea9400 100644 --- a/tests/integration_tests/clients/studio/conftest.py +++ b/tests/integration_tests/clients/studio/conftest.py @@ -42,7 +42,7 @@ def file_in_library(): # Delete any file that might be in the library due to failed tests files = client.library.files.list() for file in files: - _delete_uploaded_file(file.file_id) + _delete_uploaded_file(client, file.file_id) file_id = client.library.files.create(file_path=LIBRARY_FILE_TO_UPLOAD, labels=DEFAULT_LABELS) 
_wait_for_file_to_process(client, file_id) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 02e4d467..68c4efb5 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -1,5 +1,5 @@ import pytest -import requests +import httpx @pytest.fixture @@ -8,5 +8,5 @@ def dummy_api_host() -> str: @pytest.fixture -def mock_requests_session(mocker) -> requests.Session: - return mocker.Mock(spec=requests.Session) +def mock_httpx_client(mocker) -> httpx.Client: + return mocker.Mock(spec=httpx.Client) diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 86cde2ee..1d730f1a 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -1,8 +1,8 @@ import platform from typing import Optional +import httpx import pytest -import requests from ai21.ai21_http_client import AI21HTTPClient from ai21.http_client import HttpClient @@ -94,12 +94,12 @@ def test__execute_http_request__( params, headers, dummy_api_host: str, - mock_requests_session: requests.Session, + mock_httpx_client: httpx.Client, ): response_json = {"test_key": "test_value"} - mock_requests_session.request.return_value = MockResponse(response_json, 200) + mock_httpx_client.request.return_value = MockResponse(response_json, 200) - http_client = HttpClient(session=mock_requests_session) + http_client = HttpClient(client=mock_httpx_client) client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") response = client.execute_http_request(**params) @@ -107,7 +107,7 @@ def test__execute_http_request__( if "files" in params: # We split it because when calling requests with "files", "params" is turned into "data" - mock_requests_session.request.assert_called_once_with( + mock_httpx_client.request.assert_called_once_with( timeout=300, headers=headers, files=params["files"], @@ -116,18 +116,18 @@ def test__execute_http_request__( 
method=params["method"], ) else: - mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + mock_httpx_client.request.assert_called_once_with(timeout=300, headers=headers, **params) def test__execute_http_request__when_files_with_put_method__should_raise_value_error( dummy_api_host: str, - mock_requests_session: requests.Session, + mock_httpx_client: httpx.Client, ): response_json = {"test_key": "test_value"} - http_client = HttpClient(session=mock_requests_session) + http_client = HttpClient(client=mock_httpx_client) client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") - mock_requests_session.request.return_value = MockResponse(response_json, 200) + mock_httpx_client.request.return_value = MockResponse(response_json, 200) with pytest.raises(ValueError): params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} client.execute_http_request(**params) diff --git a/tests/unittests/test_http_client.py b/tests/unittests/test_http_client.py new file mode 100644 index 00000000..73347f09 --- /dev/null +++ b/tests/unittests/test_http_client.py @@ -0,0 +1,39 @@ +import pytest +from unittest.mock import Mock +from urllib.request import Request + +import httpx + +from ai21.errors import ServiceUnavailable +from ai21.http_client import HttpClient, RETRY_ERROR_CODES + +_METHOD = "GET" +_URL = "http://test_url" + + +def test__execute_http_request__when_retry_error_code_once__should_retry_and_succeed(mock_httpx_client: Mock) -> None: + request = Request(method=_METHOD, url=_URL) + retries = 3 + mock_httpx_client.request.side_effect = [ + httpx.Response(status_code=429, request=request), + httpx.Response(status_code=200, request=request, json={"test_key": "test_value"}), + ] + + client = HttpClient(client=mock_httpx_client, num_retries=retries) + client.execute_http_request(method=_METHOD, url=_URL) + assert 
mock_httpx_client.request.call_count == retries - 1 + + +def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_httpx_client: Mock) -> None: + request = Request(method=_METHOD, url=_URL) + retries = len(RETRY_ERROR_CODES) + + mock_httpx_client.request.side_effect = [ + httpx.Response(status_code=status_code, request=request) for status_code in RETRY_ERROR_CODES + ] + + client = HttpClient(client=mock_httpx_client, num_retries=retries) + with pytest.raises(ServiceUnavailable): + client.execute_http_request(method=_METHOD, url=_URL) + + assert mock_httpx_client.request.call_count == retries From 6afcddc2442e476e0dea3d01065a88326b280f2f Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Thu, 16 May 2024 16:38:10 +0300 Subject: [PATCH 02/14] feat: Stream support jamba (#114) * feat: Added httpx to pyproject * feat: httpx instead of requests * feat: Removed requests * fix: not given * fix: setup * feat: Added tenacity for retry * fix: conftest * test: Added tests * fix: Rename * fix: Modified test * feat: stream support (unfinished) * feat: Added tenacity for retry * test: Added tests * fix: Rename * feat: stream support (unfinished) * fix: single request creation * feat: Support stream_cls * fix: passed response_cls * fix: Removed unnecessary json_to_response * fix: imports * fix: tests * fix: imports * fix: reponse parse * fix: Added two examples to tests * fix: sagemaker tests * test: Added stream tests * fix: comment out failing test * fix: CR * fix: Removed code * docs: Added readme for streaming * fix: condition * docs: readme --- README.md | 24 ++ ai21/ai21_http_client.py | 7 +- ai21/clients/common/chat_base.py | 3 - ai21/clients/common/completion_base.py | 5 +- ai21/clients/common/dataset_base.py | 5 - ai21/clients/common/segmentation_base.py | 3 - .../common/summarize_by_segment_base.py | 3 - .../studio/resources/chat/chat_completions.py | 52 +++- 
.../clients/studio/resources/studio_answer.py | 4 +- ai21/clients/studio/resources/studio_chat.py | 3 +- .../studio/resources/studio_completion.py | 2 +- .../studio/resources/studio_custom_model.py | 8 +- .../studio/resources/studio_dataset.py | 7 +- ai21/clients/studio/resources/studio_embed.py | 3 +- ai21/clients/studio/resources/studio_gec.py | 3 +- .../studio/resources/studio_improvements.py | 3 +- .../studio/resources/studio_library.py | 17 +- .../studio/resources/studio_paraphrase.py | 3 +- .../studio/resources/studio_resource.py | 65 ++++- .../studio/resources/studio_segmentation.py | 3 +- .../studio/resources/studio_summarize.py | 3 +- .../resources/studio_summarize_by_segment.py | 4 +- ai21/errors.py | 6 + ai21/http_client.py | 30 ++- ai21/models/chat/chat_completion_chunk.py | 27 ++ ai21/services/sagemaker.py | 4 +- ai21/stream.py | 67 +++++ ai21/types.py | 12 +- ai21/utils/typing.py | 12 +- .../studio/chat/stream_chat_completions.py | 21 ++ .../integration_tests/clients/test_studio.py | 4 + .../clients/studio/resources/conftest.py | 248 ++++++++++-------- .../studio/resources/test_studio_resources.py | 30 ++- tests/unittests/conftest.py | 5 + tests/unittests/services/test_sagemaker.py | 18 +- tests/unittests/test_ai21_http_client.py | 22 +- tests/unittests/test_http_client.py | 8 +- tests/unittests/test_stream.py | 64 +++++ 38 files changed, 588 insertions(+), 220 deletions(-) create mode 100644 ai21/models/chat/chat_completion_chunk.py create mode 100644 ai21/stream.py create mode 100644 examples/studio/chat/stream_chat_completions.py create mode 100644 tests/unittests/test_stream.py diff --git a/README.md b/README.md index 98e1c244..019d79bd 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,30 @@ For a more detailed example, see the completion [examples](examples/studio/compl --- +## Streaming + +We currently support streaming for the Chat Completions API in Jamba. 
+ +```python +from ai21 import AI21Client +from ai21.models.chat import ChatMessage + +messages = [ChatMessage(content="What is the meaning of life?", role="user")] + +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="jamba-instruct-preview", + stream=True, +) +for chunk in response: + print(chunk.choices[0].delta.content, end="") + +``` + +--- + ## TSMs AI21 Studio's Task-Specific Models offer a range of powerful tools. These models have been specifically designed for their respective tasks and provide high-quality results while optimizing efficiency. diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 61c80a93..1a3b334f 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,8 @@ import platform from typing import Optional, Dict, Any, BinaryIO +import httpx + from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -76,9 +78,10 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, + stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, - ): - return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) + ) -> httpx.Response: + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files, stream=stream) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index c89a4e12..4ade9690 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -53,9 +53,6 @@ def create( def completions(self) -> ChatCompletions: pass - def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: - return ChatResponse.from_dict(json) - def _create_body( self, model: str, diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index 
b16c3990..49e5b5ea 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Dict, Any +from typing import List, Dict from ai21.models import Penalty, CompletionsResponse from ai21.types import NOT_GIVEN, NotGiven @@ -55,9 +55,6 @@ def create( """ pass - def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse: - return CompletionsResponse.from_dict(json) - def _create_body( self, model: str, diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py index dd5982ab..97bae64a 100644 --- a/ai21/clients/common/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.models import DatasetResponse - class Dataset(ABC): _module_name = "dataset" @@ -40,9 +38,6 @@ def list(self): def get(self, dataset_pid: str): pass - def _json_to_response(self, json: Dict[str, Any]) -> DatasetResponse: - return DatasetResponse.from_dict(json) - def _create_body( self, dataset_name: str, diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py index 6805e70f..074ba8c8 100644 --- a/ai21/clients/common/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -18,8 +18,5 @@ def create(self, source: str, source_type: DocumentType, **kwargs) -> Segmentati """ pass - def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: - return SegmentationResponse.from_dict(json) - def _create_body(self, source: str, source_type: str, **kwargs) -> Dict[str, Any]: return {"source": source, "sourceType": source_type, **kwargs} diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py index e5163ca9..516d0ebe 100644 --- a/ai21/clients/common/summarize_by_segment_base.py +++ 
b/ai21/clients/common/summarize_by_segment_base.py @@ -26,9 +26,6 @@ def create( """ pass - def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse: - return SummarizeBySegmentResponse.from_dict(json) - def _create_body( self, source: str, diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 3516f04e..65cc2d40 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import List, Optional, Union, Any, Dict +from typing import List, Optional, Union, Any, Dict, Literal, overload from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.chat import ChatMessage, ChatCompletionResponse from ai21.models import ChatMessage as J2ChatMessage +from ai21.models.chat import ChatMessage, ChatCompletionResponse +from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk +from ai21.stream import Stream from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given @@ -14,6 +16,7 @@ class ChatCompletions(StudioResource): _module_name = "chat/completions" + @overload def create( self, model: str, @@ -23,8 +26,38 @@ def create( top_p: float | NotGiven = NOT_GIVEN, stop: str | List[str] | NotGiven = NOT_GIVEN, n: int | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> ChatCompletionResponse: + pass + + @overload + def create( + self, + model: str, + messages: List[ChatMessage], + stream: Literal[True], + max_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + stop: str | List[str] | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> Stream[ChatCompletionChunk]: + pass + + def create( + self, + model: str, + messages: 
List[ChatMessage], + max_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + stop: str | List[str] | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> ChatCompletionResponse | Stream[ChatCompletionChunk]: if any(isinstance(item, J2ChatMessage) for item in messages): raise ValueError( "Please use the ChatMessage class from ai21.models.chat" @@ -39,12 +72,18 @@ def create( max_tokens=max_tokens, top_p=top_p, n=n, + stream=stream or False, **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post( + url=url, + body=body, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], + response_cls=ChatCompletionResponse, + ) def _create_body( self, @@ -55,6 +94,7 @@ def _create_body( top_p: Optional[float] | NotGiven, stop: Optional[Union[str, List[str]]] | NotGiven, n: Optional[int] | NotGiven, + stream: Literal[False] | Literal[True] | NotGiven, **kwargs: Any, ) -> Dict[str, Any]: return remove_not_given( @@ -66,9 +106,7 @@ def _create_body( "topP": top_p, "stop": stop, "n": n, + "stream": stream, **kwargs, } ) - - def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse: - return ChatCompletionResponse.from_dict(json) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 0831e37e..0da4ae01 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -14,6 +14,4 @@ def create( body = self._create_body(context=context, question=question, **kwargs) - response = self._post(url=url, body=body) - - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=AnswerResponse) diff --git 
a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index ab6b9cda..daccea1d 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -49,8 +49,7 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{model}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=ChatResponse) @property def completions(self) -> ChatCompletions: diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 8594105b..d2b9cdc7 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -53,4 +53,4 @@ def create( logit_bias=logit_bias, **kwargs, ) - return self._json_to_response(self._post(url=url, body=body)) + return self._post(url=url, body=body, response_cls=CompletionsResponse) diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py index 1e1fb9b4..61e351d1 100644 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ b/ai21/clients/studio/resources/studio_custom_model.py @@ -25,14 +25,12 @@ def create( num_epochs=num_epochs, **kwargs, ) - self._post(url=url, body=body) + self._post(url=url, body=body, response_cls=None) def list(self) -> List[CustomBaseModelResponse]: url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._get(url=url) - - return [self._json_to_response(r) for r in response] + return self._get(url=url, response_cls=List[CustomBaseModelResponse]) def get(self, resource_id: str) -> CustomBaseModelResponse: url = f"{self._client.get_base_url()}/{self._module_name}/{resource_id}" - return self._json_to_response(self._get(url=url)) + return self._get(url=url, response_cls=CustomBaseModelResponse) diff --git 
a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 0620fa00..10f26ddd 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -33,14 +33,11 @@ def create( ) def list(self) -> List[DatasetResponse]: - response = self._get(url=self._base_url()) - return [self._json_to_response(r) for r in response] + return self._get(url=self._base_url(), response_cls=List[DatasetResponse]) def get(self, dataset_pid: str) -> DatasetResponse: url = f"{self._base_url()}/{dataset_pid}" - response = self._get(url=url) - - return self._json_to_response(response) + return self._get(url=url, response_cls=DatasetResponse) def _base_url(self) -> str: return f"{self._client.get_base_url()}/{self._module_name}" diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index c6af637b..e45b6269 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -9,6 +9,5 @@ class StudioEmbed(StudioResource, Embed): def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type, **kwargs) - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=EmbedResponse) diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py index e5b2b2ff..a8752c9c 100644 --- a/ai21/clients/studio/resources/studio_gec.py +++ b/ai21/clients/studio/resources/studio_gec.py @@ -7,6 +7,5 @@ class StudioGEC(StudioResource, GEC): def create(self, text: str, **kwargs) -> GECResponse: body = self._create_body(text=text, **kwargs) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return 
self._json_to_response(response) + return self._post(url=url, body=body, response_cls=GECResponse) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index f767684c..88ea996c 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -13,6 +13,5 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, types=types, **kwargs) - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=ImprovementsResponse) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 8b6e6622..e177df94 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Optional, List + from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse @@ -34,15 +35,14 @@ def create( files = {"file": open(file_path, "rb")} body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs}) - raw_response = self._post(url=url, files=files, body=body) + raw_response = self._post(url=url, files=files, body=body, response_cls=dict) return raw_response["fileId"] def get(self, file_id: str) -> FileResponse: url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" - raw_response = self._get(url=url) - return FileResponse.from_dict(raw_response) + return self._get(url=url, response_cls=FileResponse) def list( self, @@ -53,9 +53,8 @@ def list( ) -> List[FileResponse]: url = 
f"{self._client.get_base_url()}/{self._module_name}" params = remove_not_given({"offset": offset, "limit": limit}) - raw_response = self._get(url=url, params=params) - return [FileResponse.from_dict(file) for file in raw_response] + return self._get(url=url, params=params, response_cls=List[FileResponse]) def update( self, @@ -102,8 +101,8 @@ def create( **kwargs, } ) - raw_response = self._post(url=url, body=body) - return LibrarySearchResponse.from_dict(raw_response) + + return self._post(url=url, body=body, response_cls=LibrarySearchResponse) class LibraryAnswer(StudioResource): @@ -128,5 +127,5 @@ def create( **kwargs, } ) - raw_response = self._post(url=url, body=body) - return LibraryAnswerResponse.from_dict(raw_response) + + return self._post(url=url, body=body, response_cls=LibraryAnswerResponse) diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index 5def5608..25764e46 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -23,6 +23,5 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=ParaphraseResponse) diff --git a/ai21/clients/studio/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py index 8ece396e..8f4be991 100644 --- a/ai21/clients/studio/resources/studio_resource.py +++ b/ai21/clients/studio/resources/studio_resource.py @@ -1,10 +1,16 @@ from __future__ import annotations +import json from abc import ABC -from typing import Any, Dict, Optional, BinaryIO +from typing import Any, Dict, Optional, BinaryIO, get_origin + +import httpx from ai21.ai21_http_client import AI21HTTPClient +from ai21.types import ResponseT, StreamT +from ai21.utils.typing import extract_type + class StudioResource(ABC): def __init__(self, client: 
AI21HTTPClient): @@ -14,23 +20,64 @@ def _post( self, url: str, body: Dict[str, Any], + response_cls: Optional[ResponseT] = None, + stream_cls: Optional[StreamT] = None, + stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, - ) -> Dict[str, Any]: - return self._client.execute_http_request( + ) -> ResponseT | StreamT: + response = self._client.execute_http_request( method="POST", url=url, + stream=stream, params=body or {}, files=files, ) - def _get(self, url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - return self._client.execute_http_request(method="GET", url=url, params=params or {}) + return self._cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls) + + def _get( + self, url: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None + ) -> ResponseT | StreamT: + response = self._client.execute_http_request(method="GET", url=url, params=params or {}) + return self._cast_response(response=response, response_cls=response_cls) - def _put(self, url: str, body: Dict[str, Any] = None) -> Dict[str, Any]: - return self._client.execute_http_request(method="PUT", url=url, params=body or {}) + def _put( + self, url: str, response_cls: Optional[ResponseT] = None, body: Dict[str, Any] = None + ) -> ResponseT | StreamT: + response = self._client.execute_http_request(method="PUT", url=url, params=body or {}) + return self._cast_response(response=response, response_cls=response_cls) - def _delete(self, url: str) -> Dict[str, Any]: - return self._client.execute_http_request( + def _delete(self, url: str, response_cls: Optional[ResponseT] = None) -> ResponseT | StreamT: + response = self._client.execute_http_request( method="DELETE", url=url, ) + return self._cast_response(response=response, response_cls=response_cls) + + def _cast_response( + self, + response: httpx.Response, + response_cls: Optional[ResponseT], + stream_cls: Optional[StreamT] = None, + stream: bool = 
False, + ) -> ResponseT | StreamT | None: + if stream and stream_cls is not None: + cast_to = extract_type(stream_cls) + return stream_cls(cast_to=cast_to, response=response) + + if response_cls is None: + return None + + if response_cls == dict: + return response.json() + + if response_cls == str: + return json.loads(response.json()) + + origin_type = get_origin(response_cls) + + if origin_type is not None and origin_type == list: + subtype = extract_type(response_cls) + return [subtype.from_dict(item) for item in response.json()] + + return response_cls.from_dict(response.json()) diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index e8e44efe..a2aee960 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -7,6 +7,5 @@ class StudioSegmentation(StudioResource, Segmentation): def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: body = self._create_body(source=source, source_type=source_type.value, **kwargs) url = f"{self._client.get_base_url()}/{self._module_name}" - raw_response = self._post(url=url, body=body) - return self._json_to_response(raw_response) + return self._post(url=url, body=body, response_cls=SegmentationResponse) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 7fd84756..6ba4f9fe 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -25,6 +25,5 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=SummarizeResponse) diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index 
292dcbaf..abb1705e 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -16,5 +16,5 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + + return self._post(url=url, body=body, response_cls=SummarizeBySegmentResponse) diff --git a/ai21/errors.py b/ai21/errors.py index 4a0f8c92..7d5bae8d 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -87,3 +87,9 @@ class EmptyMandatoryListError(AI21Error): def __init__(self, key: str): message = f"Supplied {key} is empty. At least one element should be present in the list" super().__init__(message) + + +class StreamingDecodeError(AI21Error): + def __init__(self, chunk: str): + message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format" + super().__init__(message) diff --git a/ai21/http_client.py b/ai21/http_client.py index e199cb1b..7bb1343f 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -18,8 +18,8 @@ DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 -TIME_BETWEEN_RETRIES = 1 RETRY_BACK_OFF_FACTOR = 0.5 +TIME_BETWEEN_RETRIES = 1 RETRY_ERROR_CODES = (408, 429, 500, 503) RETRY_METHOD_WHITELIST = ["GET", "POST", "PUT"] @@ -76,10 +76,11 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, + stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, - ): + ) -> httpx.Response: try: - response = self._request(files=files, method=method, params=params, url=url) + response = self._request(files=files, method=method, params=params, url=url, stream=stream) except RetryError as retry_error: last_attempt = retry_error.last_attempt @@ -99,24 +100,27 @@ def execute_http_request( logger.error(f"Calling {method} {url} failed with a non-200 response code: {response.status_code}") handle_non_success_response(response.status_code, response.text) - return 
response.json() + return response def _request( - self, files: Optional[Dict[str, BinaryIO]], method: str, params: Optional[Dict], url: str + self, + files: Optional[Dict[str, BinaryIO]], + method: str, + params: Optional[Dict], + url: str, + stream: bool, ) -> httpx.Response: timeout = self._timeout_sec headers = self._headers logger.debug(f"Calling {method} {url} {headers} {params}") if method == "GET": - return self._client.request( - method=method, - url=url, - headers=headers, - timeout=timeout, - params=params, + request = self._client.build_request( + method=method, url=url, headers=headers, timeout=timeout, params=params ) + return self._client.send(request=request, stream=stream) + if files is not None: if method != "POST": raise ValueError( @@ -130,7 +134,7 @@ def _request( else: data = json.dumps(params).encode() if params else None - return self._client.request( + request = self._client.build_request( method=method, url=url, headers=headers, @@ -139,6 +143,8 @@ def _request( files=files, ) + return self._client.send(request=request, stream=stream) + def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client: if client is not None: return client diff --git a/ai21/models/chat/chat_completion_chunk.py b/ai21/models/chat/chat_completion_chunk.py new file mode 100644 index 00000000..2a63e63c --- /dev/null +++ b/ai21/models/chat/chat_completion_chunk.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Optional, List + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.models.logprobs import Logprobs +from ai21.models.usage_info import UsageInfo + + +@dataclass +class ChoiceDelta(AI21BaseModelMixin): + content: Optional[str] = None + role: Optional[str] = None + + +@dataclass +class ChoicesChunk(AI21BaseModelMixin): + index: int + delta: ChoiceDelta + logprobs: Optional[Logprobs] = None + finish_reason: Optional[str] = None + + +@dataclass +class ChatCompletionChunk(AI21BaseModelMixin): + id: str + 
choices: List[ChoicesChunk] + usage: Optional[UsageInfo] = None diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index 627fbe2e..3132cd1a 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -31,7 +31,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = response["arn"] + arn = response.json()["arn"] if not arn: raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) @@ -53,7 +53,7 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: }, ) - return response["versions"] + return response.json()["versions"] @classmethod def _create_ai21_http_client(cls) -> AI21HTTPClient: diff --git a/ai21/stream.py b/ai21/stream.py new file mode 100644 index 00000000..bd324000 --- /dev/null +++ b/ai21/stream.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import json +from typing import TypeVar, Generic, Iterator, Optional + +import httpx + +from ai21.errors import StreamingDecodeError + +_T = TypeVar("_T") +_SSE_DATA_PREFIX = "data: " +_SSE_DONE_MSG = "[DONE]" + + +class Stream(Generic[_T]): + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[_T], + response: httpx.Response, + ): + self.response = response + self.cast_to = cast_to + self._decoder = _SSEDecoder() + self._iterator = self.__stream__() + + def __next__(self) -> _T: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[_T]: + for item in self._iterator: + yield item + + def __stream__(self) -> Iterator[_T]: + for chunk in self._decoder.iter(self.response.iter_lines()): + if chunk.endswith(_SSE_DONE_MSG): + break + + try: + chunk = json.loads(chunk) + if hasattr(self.cast_to, "from_dict"): + yield self.cast_to.from_dict(chunk) + else: + yield self.cast_to(**chunk) + except json.JSONDecodeError: + raise StreamingDecodeError(chunk) + + +class _SSEDecoder: + def iter(self, iterator: Iterator[str]): + for line in iterator: + 
line = line.strip()
+            decoded_line = self._decode(line)
+
+            if decoded_line is not None:
+                yield decoded_line
+
+    def _decode(self, line: str) -> Optional[str]:
+        if not line:
+            return None
+
+        if line.startswith(_SSE_DATA_PREFIX):
+            return line[len(_SSE_DATA_PREFIX):]
+
+        raise StreamingDecodeError(f"Invalid SSE line: {line}")
diff --git a/ai21/types.py b/ai21/types.py
index 137c938d..13a7a035 100644
--- a/ai21/types.py
+++ b/ai21/types.py
@@ -1,4 +1,14 @@
-from typing_extensions import Literal
+from typing import Any, Union, List
+
+import httpx
+from typing_extensions import Literal, TypeVar, TYPE_CHECKING
+from ai21.stream import Stream
+
+if TYPE_CHECKING:
+    from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin  # noqa
+
+ResponseT = TypeVar("ResponseT", bound=Union["AI21BaseModelMixin", str, httpx.Response, List[Any]])
+StreamT = TypeVar("StreamT", bound=Stream[Any])
 
 
 # Sentinel class used until PEP 0661 is accepted
diff --git a/ai21/utils/typing.py b/ai21/utils/typing.py
index ae77329b..be0244c7 100644
--- a/ai21/utils/typing.py
+++ b/ai21/utils/typing.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, get_args, cast
 
 from ai21.types import NotGiven
 
@@ -20,3 +20,13 @@ def to_lower_camel_case(snake_str: str) -> str:
     # with the 'capitalize' method and join them together.
camel_string = to_camel_case(snake_str) return snake_str[0].lower() + camel_string[1:] + + +def extract_type(type_to_extract: Any) -> type: + args = get_args(type_to_extract) + try: + return cast(type, args[0]) + except IndexError as err: + raise RuntimeError( + f"Expected type {type_to_extract} to have a type argument at index 0 but it did not" + ) from err diff --git a/examples/studio/chat/stream_chat_completions.py b/examples/studio/chat/stream_chat_completions.py new file mode 100644 index 00000000..fd079962 --- /dev/null +++ b/examples/studio/chat/stream_chat_completions.py @@ -0,0 +1,21 @@ +from ai21 import AI21Client +from ai21.models.chat import ChatMessage + +system = "You're a support engineer in a SaaS company" +messages = [ + ChatMessage(content=system, role="system"), + ChatMessage(content="Hello, I need help with a signup process.", role="user"), + ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), +] + +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="jamba-instruct-preview", + max_tokens=100, + stream=True, +) +for chunk in response: + print(chunk.choices[0].delta.content, end="") diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index e5500d22..2d2d78f8 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -27,6 +27,8 @@ ("summarize.py",), ("summarize_by_segment.py",), ("tokenization.py",), + ("chat/chat_completions.py",), + # ("chat/stream_chat_completions.py",), # Uncomment when streaming is supported in production # ("custom_model.py", ), # ('custom_model_completion.py', ), # ("dataset.py", ), @@ -45,6 +47,8 @@ "when_summarize__should_return_ok", "when_summarize_by_segment__should_return_ok", 
"when_tokenization__should_return_ok", + "when_chat_completions__should_return_ok", + # "when_stream_chat_completions__should_return_ok", # "when_custom_model__should_return_ok", # "when_custom_model_completion__should_return_ok", # "when_dataset__should_return_ok", diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 132f95ad..627a2eee 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -1,4 +1,5 @@ import pytest +import httpx from pytest_mock import MockerFixture from ai21.ai21_http_client import AI21HTTPClient @@ -18,39 +19,24 @@ ChatMessage, RoleType, ChatResponse, - ChatOutput, - FinishReason, - Prompt, - Completion, - CompletionData, - CompletionFinishReason, CompletionsResponse, EmbedType, EmbedResponse, - EmbedResult, GECResponse, - Correction, - CorrectionType, ImprovementType, ImprovementsResponse, - Improvement, ParaphraseStyleType, ParaphraseResponse, - Suggestion, DocumentType, SegmentationResponse, SummaryMethod, SummarizeResponse, SummarizeBySegmentResponse, - SegmentSummary, ) from ai21.models.chat import ( ChatMessage as ChatCompletionChatMessage, ChatCompletionResponse, - ChatCompletionResponseChoice, ) -from ai21.models.responses.segmentation_response import Segment -from ai21.models.usage_info import UsageInfo from ai21.utils.typing import to_lower_camel_case @@ -59,9 +45,18 @@ def mock_ai21_studio_client(mocker: MockerFixture) -> AI21HTTPClient: return mocker.MagicMock(spec=AI21HTTPClient) +@pytest.fixture +def mock_successful_httpx_response(mocker: MockerFixture) -> httpx.Response: + mock_httpx_response = mocker.Mock(spec=httpx.Response) + mock_httpx_response.status_code = 200 + + return mock_httpx_response + + def get_studio_answer(): _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" _DUMMY_QUESTION = "What is the answer?" 
+ json_response = {"id": "some-id", "answer_in_context": True, "answer": "42"} return ( StudioAnswer, @@ -71,7 +66,8 @@ def get_studio_answer(): "context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION, }, - AnswerResponse(id="some-id", answer_in_context=True, answer="42"), + httpx.Response(status_code=200, json=json_response), + AnswerResponse.from_dict(json_response), ) @@ -85,6 +81,15 @@ def get_studio_chat(): ), ] _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" + json_response = { + "outputs": [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "finishReason": {"reason": "dummy_reason", "length": 1, "sequence": "1"}, + } + ] + } return ( StudioChat, @@ -105,15 +110,8 @@ def get_studio_chat(): "presencePenalty": None, "countPenalty": None, }, - ChatResponse( - outputs=[ - ChatOutput( - text="Hello, I need help with a signup process.", - role="user", - finish_reason=FinishReason(reason="dummy_reason", length=1, sequence="1"), - ) - ] - ), + httpx.Response(status_code=200, json=json_response), + ChatResponse.from_dict(json_response), ) @@ -127,6 +125,25 @@ def get_chat_completions(): ), ] _EXPECTED_SERIALIZED_MESSAGES = [message.to_dict() for message in _DUMMY_MESSAGES] + json_response = { + "id": "some-id", + "choices": [ + { + "index": 0, + "message": { + "content": "Hello, I need help with a signup process.", + "role": "user", + }, + "finishReason": "dummy_reason", + "logprobs": None, + } + ], + "usage": { + "promptTokens": 10, + "completionTokens": 20, + "totalTokens": 30, + }, + } return ( ChatCompletions, @@ -135,31 +152,26 @@ def get_chat_completions(): { "model": _DUMMY_MODEL, "messages": _EXPECTED_SERIALIZED_MESSAGES, + "stream": False, }, - ChatCompletionResponse( - id="some-id", - choices=[ - ChatCompletionResponseChoice( - index=0, - message=ChatCompletionChatMessage( - content="Hello, I need help with a signup process.", role=RoleType.USER - ), - finish_reason="dummy_reason", - logprobs=None, - ) - ], - 
usage=UsageInfo( - prompt_tokens=10, - completion_tokens=20, - total_tokens=30, - ), - ), + httpx.Response(status_code=200, json=json_response), + ChatCompletionResponse.from_dict(json_response), ) def get_studio_completion(**kwargs): _DUMMY_MODEL = "dummy-completion-model" _DUMMY_PROMPT = "dummy-prompt" + json_response = { + "id": "some-id", + "completions": [ + { + "data": {"text": "dummy-completion", "tokens": []}, + "finishReason": {"reason": "dummy_reason", "length": 1}, + } + ], + "prompt": {"text": "dummy-prompt"}, + } return ( StudioCompletion, @@ -170,20 +182,20 @@ def get_studio_completion(**kwargs): "prompt": _DUMMY_PROMPT, **{to_lower_camel_case(k): v for k, v in kwargs.items()}, }, - CompletionsResponse( - id="some-id", - completions=[ - Completion( - data=CompletionData(text="dummy-completion", tokens=[]), - finish_reason=CompletionFinishReason(reason="dummy_reason", length=1), - ) - ], - prompt=Prompt(text="dummy-prompt"), - ), + httpx.Response(status_code=200, json=json_response), + CompletionsResponse.from_dict(json_response), ) def get_studio_embed(): + json_response = { + "id": "some-id", + "results": [ + {"embedding": [1.0, 2.0, 3.0]}, + {"embedding": [4.0, 5.0, 6.0]}, + ], + } + return ( StudioEmbed, {"texts": ["text0", "text1"], "type": EmbedType.QUERY}, @@ -192,18 +204,26 @@ def get_studio_embed(): "texts": ["text0", "text1"], "type": EmbedType.QUERY.value, }, - EmbedResponse( - id="some-id", - results=[ - EmbedResult([1.0, 2.0, 3.0]), - EmbedResult([4.0, 5.0, 6.0]), - ], - ), + httpx.Response(status_code=200, json=json_response), + EmbedResponse.from_dict(json_response), ) def get_studio_gec(): + json_response = { + "id": "some-id", + "corrections": [ + { + "suggestion": "text to fix", + "startIndex": 9, + "endIndex": 10, + "originalText": "text to fi", + "correctionType": "Spelling", + } + ], + } text = "text to fi" + return ( StudioGEC, {"text": text}, @@ -211,24 +231,27 @@ def get_studio_gec(): { "text": text, }, - GECResponse( - 
id="some-id", - corrections=[ - Correction( - suggestion="text to fix", - start_index=9, - end_index=10, - original_text=text, - correction_type=CorrectionType.SPELLING, - ) - ], - ), + httpx.Response(status_code=200, json=json_response), + GECResponse.from_dict(json_response), ) def get_studio_improvements(): + json_response = { + "id": "some-id", + "improvements": [ + { + "suggestions": ["This text is improved"], + "startIndex": 0, + "endIndex": 15, + "originalText": "text to improve", + "improvementType": "FLUENCY", + } + ], + } text = "text to improve" types = [ImprovementType.FLUENCY] + return ( StudioImprovements, {"text": text, "types": types}, @@ -237,18 +260,8 @@ def get_studio_improvements(): "text": text, "types": types, }, - ImprovementsResponse( - id="some-id", - improvements=[ - Improvement( - suggestions=["This text is improved"], - start_index=0, - end_index=15, - original_text=text, - improvement_type=ImprovementType.FLUENCY, - ) - ], - ), + httpx.Response(status_code=200, json=json_response), + ImprovementsResponse.from_dict(json_response), ) @@ -257,6 +270,15 @@ def get_studio_paraphrase(): style = ParaphraseStyleType.CASUAL start_index = 0 end_index = 10 + json_response = { + "id": "some-id", + "suggestions": [ + { + "text": "This text is paraphrased", + } + ], + } + return ( StudioParaphrase, {"text": text, "style": style, "start_index": start_index, "end_index": end_index}, @@ -267,13 +289,24 @@ def get_studio_paraphrase(): "startIndex": start_index, "endIndex": end_index, }, - ParaphraseResponse(id="some-id", suggestions=[Suggestion(text="This text is paraphrased")]), + httpx.Response(status_code=200, json=json_response), + ParaphraseResponse.from_dict(json_response), ) def get_studio_segmentation(): source = "segmentation text" source_type = DocumentType.TEXT + json_response = { + "id": "some-id", + "segments": [ + { + "segmentText": "This text is segmented", + "segmentType": "segment_type", + } + ], + } + return ( StudioSegmentation, 
{"source": source, "source_type": source_type}, @@ -282,9 +315,8 @@ def get_studio_segmentation(): "source": source, "sourceType": source_type, }, - SegmentationResponse( - id="some-id", segments=[Segment(segment_text="This text is segmented", segment_type="segment_type")] - ), + httpx.Response(status_code=200, json=json_response), + SegmentationResponse.from_dict(json_response), ) @@ -293,6 +325,11 @@ def get_studio_summarization(): source_type = "TEXT" focus = "text" summary_method = SummaryMethod.FULL_DOCUMENT + json_response = { + "id": "some-id", + "summary": "This text is summarized", + } + return ( StudioSummarize, {"source": source, "source_type": source_type, "focus": focus, "summary_method": summary_method}, @@ -303,10 +340,8 @@ def get_studio_summarization(): "focus": focus, "summaryMethod": summary_method, }, - SummarizeResponse( - id="some-id", - summary="This text is summarized", - ), + httpx.Response(status_code=200, json=json_response), + SummarizeResponse.from_dict(json_response), ) @@ -314,6 +349,20 @@ def get_studio_summarize_by_segment(): source = "text to summarize" source_type = "TEXT" focus = "text" + json_response = { + "id": "some-id", + "segments": [ + { + "summary": "This text is summarized", + "segmentText": "This text is segmented", + "segmentHtml": "", + "segmentType": "segment_type", + "hasSummary": True, + "highlights": [], + } + ], + } + return ( StudioSummarizeBySegment, {"source": source, "source_type": source_type, "focus": focus}, @@ -323,17 +372,6 @@ def get_studio_summarize_by_segment(): "sourceType": source_type, "focus": focus, }, - SummarizeBySegmentResponse( - id="some-id", - segments=[ - SegmentSummary( - summary="This text is summarized", - segment_text="This text is segmented", - segment_type="segment_type", - segment_html=None, - has_summary=True, - highlights=[], - ) - ], - ), + httpx.Response(status_code=200, json=json_response), + SummarizeBySegmentResponse.from_dict(json_response), ) diff --git 
a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 33c02318..f5ad3196 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -1,7 +1,7 @@ from typing import TypeVar, Callable import pytest - +import httpx from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_resource import StudioResource @@ -43,7 +43,14 @@ class TestStudioResources: "studio_summarization", "studio_summarize_by_segment", ], - argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], + argnames=[ + "studio_resource", + "function_body", + "url_suffix", + "expected_body", + "expected_httpx_response", + "expected_response", + ], argvalues=[ (get_studio_answer()), (get_studio_chat()), @@ -59,16 +66,17 @@ class TestStudioResources: (get_studio_summarize_by_segment()), ], ) - def test__create__should_return_answer_response( + def test__create__should_return_response( self, studio_resource: Callable[[AI21HTTPClient], T], function_body, url_suffix: str, expected_body, - expected_response, + expected_httpx_response, + expected_response: AnswerResponse, mock_ai21_studio_client: AI21HTTPClient, ): - mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() + mock_ai21_studio_client.execute_http_request.return_value = expected_httpx_response mock_ai21_studio_client.get_base_url.return_value = _BASE_URL resource = studio_resource(mock_ai21_studio_client) @@ -82,12 +90,19 @@ def test__create__should_return_answer_response( method="POST", url=f"{_BASE_URL}/{url_suffix}", params=expected_body, + stream=False, files=None, ) - def test__create__when_pass_kwargs__should_pass_to_request(self, mock_ai21_studio_client: AI21HTTPClient): + def 
test__create__when_pass_kwargs__should_pass_to_request( + self, + mock_ai21_studio_client: AI21HTTPClient, + mock_successful_httpx_response: httpx.Response, + ): expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") - mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() + mock_successful_httpx_response.json.return_value = expected_answer.to_dict() + + mock_ai21_studio_client.execute_http_request.return_value = mock_successful_httpx_response mock_ai21_studio_client.get_base_url.return_value = _BASE_URL studio_answer = StudioAnswer(mock_ai21_studio_client) @@ -105,5 +120,6 @@ def test__create__when_pass_kwargs__should_pass_to_request(self, mock_ai21_studi "question": _DUMMY_QUESTION, "some_dummy_kwargs": "some_dummy_value", }, + stream=False, files=None, ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 68c4efb5..f8efe8ae 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -10,3 +10,8 @@ def dummy_api_host() -> str: @pytest.fixture def mock_httpx_client(mocker) -> httpx.Client: return mocker.Mock(spec=httpx.Client) + + +@pytest.fixture +def mock_httpx_response(mocker) -> httpx.Response: + return mocker.Mock(spec=httpx.Response) diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index c6f9e165..1e8fdc04 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,3 +1,4 @@ +import httpx import pytest from ai21 import ModelPackageDoesntExistError @@ -8,28 +9,29 @@ class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self): - expected_response = { + def test__get_model_package_arn__should_return_model_package_arn(self, mock_httpx_response: httpx.Response): + mock_httpx_response.json.return_value = { "arn": _DUMMY_ARN, "versions": _DUMMY_VERSIONS, } - SageMakerStub.ai21_http_client.execute_http_request.return_value 
= expected_response + SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") assert actual_model_package_arn == _DUMMY_ARN - def test__get_model_package_arn__when_no_arn__should_raise_error(self): - SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + def test__get_model_package_arn__when_no_arn__should_raise_error(self, mock_httpx_response: httpx.Response): + mock_httpx_response.json.return_value = {"arn": []} + SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") - def test__list_model_package_versions__should_return_model_package_arn(self): - expected_response = { + def test__list_model_package_versions__should_return_model_package_arn(self, mock_httpx_response: httpx.Response): + mock_httpx_response.json.return_value = { "versions": _DUMMY_VERSIONS, } - SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 1d730f1a..3ae8f6c6 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -1,5 +1,7 @@ import platform from typing import Optional +from unittest.mock import Mock +from urllib.request import Request import httpx import pytest @@ -85,7 +87,13 @@ def test__get_base_url(api_host: Optional[str], expected_api_host: str): argvalues=[ ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), ( - {"method": "POST", "url": "test_url", "params": 
{"foo": "bar"}, "files": {"file": "test_file"}}, + { + "method": "POST", + "url": "test_url", + "params": {"foo": "bar"}, + "stream": False, + "files": {"file": "test_file"}, + }, _EXPECTED_POST_FILE_HEADERS, ), ], @@ -97,17 +105,19 @@ def test__execute_http_request__( mock_httpx_client: httpx.Client, ): response_json = {"test_key": "test_value"} - mock_httpx_client.request.return_value = MockResponse(response_json, 200) + mock_response = Mock(spec=Request) + mock_httpx_client.build_request.return_value = mock_response + mock_httpx_client.send.return_value = MockResponse(response_json, 200) http_client = HttpClient(client=mock_httpx_client) client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") response = client.execute_http_request(**params) - assert response == response_json + assert response.json() == response_json if "files" in params: # We split it because when calling requests with "files", "params" is turned into "data" - mock_httpx_client.request.assert_called_once_with( + mock_httpx_client.build_request.assert_called_once_with( timeout=300, headers=headers, files=params["files"], @@ -116,7 +126,9 @@ def test__execute_http_request__( method=params["method"], ) else: - mock_httpx_client.request.assert_called_once_with(timeout=300, headers=headers, **params) + mock_httpx_client.build_request.assert_called_once_with(timeout=300, headers=headers, **params) + + mock_httpx_client.send.assert_called_once_with(request=mock_response, stream=False) def test__execute_http_request__when_files_with_put_method__should_raise_value_error( diff --git a/tests/unittests/test_http_client.py b/tests/unittests/test_http_client.py index 73347f09..0342c426 100644 --- a/tests/unittests/test_http_client.py +++ b/tests/unittests/test_http_client.py @@ -14,21 +14,21 @@ def test__execute_http_request__when_retry_error_code_once__should_retry_and_succeed(mock_httpx_client: Mock) -> None: request = Request(method=_METHOD, 
url=_URL) retries = 3 - mock_httpx_client.request.side_effect = [ + mock_httpx_client.send.side_effect = [ httpx.Response(status_code=429, request=request), httpx.Response(status_code=200, request=request, json={"test_key": "test_value"}), ] client = HttpClient(client=mock_httpx_client, num_retries=retries) client.execute_http_request(method=_METHOD, url=_URL) - assert mock_httpx_client.request.call_count == retries - 1 + assert mock_httpx_client.send.call_count == retries - 1 def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_httpx_client: Mock) -> None: request = Request(method=_METHOD, url=_URL) retries = len(RETRY_ERROR_CODES) - mock_httpx_client.request.side_effect = [ + mock_httpx_client.send.side_effect = [ httpx.Response(status_code=status_code, request=request) for status_code in RETRY_ERROR_CODES ] @@ -36,4 +36,4 @@ def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_htt with pytest.raises(ServiceUnavailable): client.execute_http_request(method=_METHOD, url=_URL) - assert mock_httpx_client.request.call_count == retries + assert mock_httpx_client.send.call_count == retries diff --git a/tests/unittests/test_stream.py b/tests/unittests/test_stream.py new file mode 100644 index 00000000..7abd2563 --- /dev/null +++ b/tests/unittests/test_stream.py @@ -0,0 +1,64 @@ +import json +from dataclasses import dataclass +from typing import AsyncIterable + +import httpx +import pytest + +from ai21.errors import StreamingDecodeError +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.stream import Stream + + +@dataclass +class StubStreamObject(AI21BaseModelMixin): + id: str + name: str + + +def async_byte_stream() -> AsyncIterable[bytes]: + for i in range(10): + data = {"id": f"some-{i}", "name": f"some-name-{i}"} + msg = f"data: {json.dumps(data)}\r\n" + yield msg.encode("utf-8") + + +def async_byte_bad_stream_prefix() -> AsyncIterable[bytes]: + msg = "bad_stream: {}\r\n" + yield 
msg.encode("utf-8") + + +def async_byte_bad_stream_json_format() -> AsyncIterable[bytes]: + msg = "data: not a json format\r\n" + yield msg.encode("utf-8") + + +def test_stream_object_when_json_string_ok__should_be_ok(): + stream = async_byte_stream() + response = httpx.Response(status_code=200, content=stream) + stream_obj = Stream[StubStreamObject](response=response, cast_to=StubStreamObject) + + for i, chunk in enumerate(stream_obj): + assert isinstance(chunk, StubStreamObject) + assert chunk.name == f"some-name-{i}" + assert chunk.id == f"some-{i}" + + +@pytest.mark.parametrize( + ids=[ + "bad_stream_data_prefix", + "bad_stream_json_format", + ], + argnames=["stream"], + argvalues=[ + (async_byte_bad_stream_prefix(),), + (async_byte_bad_stream_json_format(),), + ], +) +def test_stream_object_when_bad_json__should_raise_error(stream): + response = httpx.Response(status_code=200, content=stream) + stream_obj = Stream[StubStreamObject](response=response, cast_to=StubStreamObject) + + with pytest.raises(StreamingDecodeError): + for _ in stream_obj: + pass From ab1f8a2a874c1958b7a5f3cfff214d433767cef1 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Sun, 19 May 2024 08:24:53 +0300 Subject: [PATCH 03/14] test: Added integration test for streaming --- .../studio/resources/chat/chat_completions.py | 3 +-- ai21/models/chat/__init__.py | 11 +++++++++- .../clients/studio/test_chat_completions.py | 21 +++++++++++++++++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 65cc2d40..cddc8cb3 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -4,8 +4,7 @@ from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import ChatMessage as J2ChatMessage -from ai21.models.chat import ChatMessage, ChatCompletionResponse -from 
ai21.models.chat.chat_completion_chunk import ChatCompletionChunk +from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk from ai21.stream import Stream from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py index 0fb4df66..d9332ff1 100644 --- a/ai21/models/chat/__init__.py +++ b/ai21/models/chat/__init__.py @@ -4,5 +4,14 @@ from .chat_completion_response import ChatCompletionResponseChoice from .chat_message import ChatMessage from .role_type import RoleType as RoleType +from .chat_completion_chunk import ChatCompletionChunk, ChoicesChunk, ChoiceDelta -__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice", "ChatMessage", "RoleType"] +__all__ = [ + "ChatCompletionResponse", + "ChatCompletionResponseChoice", + "ChatMessage", + "RoleType", + "ChatCompletionChunk", + "ChoicesChunk", + "ChoiceDelta", +] diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 684f0aa5..d564aeac 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -1,8 +1,7 @@ from ai21 import AI21Client -from ai21.models.chat import ChatMessage, ChatCompletionResponse +from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk, ChoicesChunk, ChoiceDelta from ai21.models import RoleType - _MODEL = "jamba-instruct-preview" _MESSAGES = [ ChatMessage( @@ -50,3 +49,21 @@ def test_chat_completion__with_n_param__should_return_n_choices(): for choice in response.choices: assert choice.message.content assert choice.message.role + + +def test_chat_completion__when_stream__should_return_chunks(): + messages = _MESSAGES + + client = AI21Client() + + response = client.chat.completions.create( + model=_MODEL, + messages=messages, + temperature=0, + 
stream=True, + ) + + for chunk in response: + assert isinstance(chunk, ChatCompletionChunk) + assert isinstance(chunk.choices[0], ChoicesChunk) + assert isinstance(chunk.choices[0].delta, ChoiceDelta) From 5323220fd408365b818f3c57b17671e020b79b7f Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Sun, 19 May 2024 15:13:26 +0300 Subject: [PATCH 04/14] fix: Added enter and close to stream --- ai21/stream.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ai21/stream.py b/ai21/stream.py index bd324000..de3b01d3 100644 --- a/ai21/stream.py +++ b/ai21/stream.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +from types import TracebackType from typing import TypeVar, Generic, Iterator, Optional import httpx +from typing_extensions import Self from ai21.errors import StreamingDecodeError @@ -47,6 +49,20 @@ def __stream__(self) -> Iterator[_T]: except json.JSONDecodeError: raise StreamingDecodeError(chunk) + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self): + self.response.close() + class _SSEDecoder: def iter(self, iterator: Iterator[str]): From 9fe6710f90b25fdd391792d42ad949ebcbcef413 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 20 May 2024 12:00:23 +0300 Subject: [PATCH 05/14] fix: Uncomment chat completions test --- tests/integration_tests/clients/test_studio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 2d2d78f8..42e7abf4 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -28,7 +28,7 @@ ("summarize_by_segment.py",), ("tokenization.py",), ("chat/chat_completions.py",), - # ("chat/stream_chat_completions.py",), # Uncomment when streaming is supported in production + 
("chat/stream_chat_completions.py",), # ("custom_model.py", ), # ('custom_model_completion.py', ), # ("dataset.py", ), From cc044f5a3df1f2299d6bab01fe6b7642b08f742d Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Mon, 13 May 2024 10:52:39 +0300 Subject: [PATCH 06/14] feat: Add httpx support (#111) * feat: Added httpx to pyproject * feat: httpx instead of requests * feat: Removed requests * fix: not given * fix: setup * feat: Added tenacity for retry * fix: conftest * test: Added tests * fix: Rename * fix: Modified test * fix: CR * fix: request --- .../studio/resources/studio_library.py | 77 +++++----- ai21/http_client.py | 133 ++++++++++-------- poetry.lock | 108 +++++++++++++- pyproject.toml | 3 +- setup.py | 2 +- tests/unittests/conftest.py | 6 +- tests/unittests/test_ai21_http_client.py | 18 +-- tests/unittests/test_http_client.py | 39 +++++ 8 files changed, 274 insertions(+), 112 deletions(-) create mode 100644 tests/unittests/test_http_client.py diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 3288c7c7..8b6e6622 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import Optional, List from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse +from ai21.types import NotGiven, NOT_GIVEN +from ai21.utils.typing import remove_not_given class StudioLibrary(StudioResource): @@ -22,14 +25,14 @@ def create( self, file_path: str, *, - path: Optional[str] = None, - labels: Optional[List[str]] = None, - public_url: Optional[str] = None, + path: Optional[str] | NotGiven = NOT_GIVEN, + labels: Optional[List[str]] | NotGiven = NOT_GIVEN, + public_url: Optional[str] | NotGiven = 
NOT_GIVEN, **kwargs, ) -> str: url = f"{self._client.get_base_url()}/{self._module_name}" files = {"file": open(file_path, "rb")} - body = {"path": path, "labels": labels, "publicUrl": public_url, **kwargs} + body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs}) raw_response = self._post(url=url, files=files, body=body) @@ -44,12 +47,12 @@ def get(self, file_id: str) -> FileResponse: def list( self, *, - offset: Optional[int] = None, - limit: Optional[int] = None, + offset: Optional[int] | NotGiven = NOT_GIVEN, + limit: Optional[int] | NotGiven = NOT_GIVEN, **kwargs, ) -> List[FileResponse]: url = f"{self._client.get_base_url()}/{self._module_name}" - params = {"offset": offset, "limit": limit} + params = remove_not_given({"offset": offset, "limit": limit}) raw_response = self._get(url=url, params=params) return [FileResponse.from_dict(file) for file in raw_response] @@ -58,16 +61,18 @@ def update( self, file_id: str, *, - public_url: Optional[str] = None, - labels: Optional[List[str]] = None, + public_url: Optional[str] | NotGiven = NOT_GIVEN, + labels: Optional[List[str]] | NotGiven = NOT_GIVEN, **kwargs, ) -> None: url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" - body = { - "publicUrl": public_url, - "labels": labels, - **kwargs, - } + body = remove_not_given( + { + "publicUrl": public_url, + "labels": labels, + **kwargs, + } + ) self._put(url=url, body=body) def delete(self, file_id: str) -> None: @@ -82,19 +87,21 @@ def create( self, query: str, *, - path: Optional[str] = None, - field_ids: Optional[List[str]] = None, - max_segments: Optional[int] = None, + path: Optional[str] | NotGiven = NOT_GIVEN, + field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, + max_segments: Optional[int] | NotGiven = NOT_GIVEN, **kwargs, ) -> LibrarySearchResponse: url = f"{self._client.get_base_url()}/{self._module_name}" - body = { - "query": query, - "path": path, - "fieldIds": field_ids, - "maxSegments": 
max_segments, - **kwargs, - } + body = remove_not_given( + { + "query": query, + "path": path, + "fieldIds": field_ids, + "maxSegments": max_segments, + **kwargs, + } + ) raw_response = self._post(url=url, body=body) return LibrarySearchResponse.from_dict(raw_response) @@ -106,18 +113,20 @@ def create( self, question: str, *, - path: Optional[str] = None, - field_ids: Optional[List[str]] = None, - labels: Optional[List[str]] = None, + path: Optional[str] | NotGiven = NOT_GIVEN, + field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, + labels: Optional[List[str]] | NotGiven = NOT_GIVEN, **kwargs, ) -> LibraryAnswerResponse: url = f"{self._client.get_base_url()}/{self._module_name}" - body = { - "question": question, - "path": path, - "fieldIds": field_ids, - "labels": labels, - **kwargs, - } + body = remove_not_given( + { + "question": question, + "path": path, + "fieldIds": field_ids, + "labels": labels, + **kwargs, + } + ) raw_response = self._post(url=url, body=body) return LibraryAnswerResponse.from_dict(raw_response) diff --git a/ai21/http_client.py b/ai21/http_client.py index 229c5156..e199cb1b 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,8 +1,9 @@ import json from typing import Optional, Dict, Any, BinaryIO -import requests -from requests.adapters import HTTPAdapter, Retry, RetryError +import httpx +from httpx import ConnectError +from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential, RetryError from ai21.errors import ( BadRequest, @@ -17,8 +18,8 @@ DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 +TIME_BETWEEN_RETRIES = 1 RETRY_BACK_OFF_FACTOR = 0.5 -TIME_BETWEEN_RETRIES = 1000 RETRY_ERROR_CODES = (408, 429, 500, 503) RETRY_METHOD_WHITELIST = ["GET", "POST", "PUT"] @@ -39,25 +40,16 @@ def handle_non_success_response(status_code: int, response_text: str): raise AI21APIError(status_code, details=response_text) -def requests_retry_session(session, retries=0): - retry = Retry( - total=retries, - read=retries, - 
connect=retries, - backoff_factor=RETRY_BACK_OFF_FACTOR, - status_forcelist=RETRY_ERROR_CODES, - allowed_methods=frozenset(RETRY_METHOD_WHITELIST), +def _requests_retry_session(retries: int) -> httpx.HTTPTransport: + return httpx.HTTPTransport( + retries=retries, ) - adapter = HTTPAdapter(max_retries=retry) - session.mount("https://", adapter) - session.mount("http://", adapter) - return session class HttpClient: def __init__( self, - session: Optional[requests.Session] = None, + client: Optional[httpx.Client] = None, timeout_sec: int = None, num_retries: int = None, headers: Dict = None, @@ -66,7 +58,18 @@ def __init__( self._num_retries = num_retries or DEFAULT_NUM_RETRIES self._headers = headers or {} self._apply_retry_policy = self._num_retries > 0 - self._session = self._init_session(session) + self._client = self._init_client(client) + + # Since we can't use the retry decorator on a method of a class as we can't access class attributes, + # we have to wrap the method in a function + self._request = retry( + wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES), + retry=retry_if_result(self._should_retry), + stop=stop_after_attempt(self._num_retries), + )(self._request) + + def _should_retry(self, response: httpx.Response) -> bool: + return response.status_code in RETRY_ERROR_CODES and response.request.method in RETRY_METHOD_WHITELIST def execute_http_request( self, @@ -75,66 +78,72 @@ def execute_http_request( params: Optional[Dict] = None, files: Optional[Dict[str, BinaryIO]] = None, ): - timeout = self._timeout_sec - headers = self._headers - data = json.dumps(params).encode() - logger.debug(f"Calling {method} {url} {headers} {data}") - try: - if method == "GET": - response = self._session.request( - method=method, - url=url, - headers=headers, - timeout=timeout, - params=params, - ) - elif files is not None: - if method != "POST": - raise ValueError( - f"execute_http_request supports only POST for files upload, but {method} was 
supplied instead" - ) - if "Content-Type" in headers: - headers.pop( - "Content-Type" - ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = self._session.request( - method=method, - url=url, - headers=headers, - data=params, - files=files, - timeout=timeout, - ) + response = self._request(files=files, method=method, params=params, url=url) + except RetryError as retry_error: + last_attempt = retry_error.last_attempt + + if last_attempt.failed: + raise last_attempt.exception() else: - response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) - except ConnectionError as connection_error: + response = last_attempt.result() + + except ConnectError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error - except RetryError as retry_error: - logger.error( - f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}" - ) - raise retry_error except Exception as exception: logger.error(f"Calling {method} {url} failed with Exception: {exception}") raise exception - if response.status_code != 200: + if response.status_code != httpx.codes.OK: logger.error(f"Calling {method} {url} failed with a non-200 response code: {response.status_code}") handle_non_success_response(response.status_code, response.text) return response.json() - def _init_session(self, session: Optional[requests.Session]) -> requests.Session: - if session is not None: - return session + def _request( + self, files: Optional[Dict[str, BinaryIO]], method: str, params: Optional[Dict], url: str + ) -> httpx.Response: + timeout = self._timeout_sec + headers = self._headers + logger.debug(f"Calling {method} {url} {headers} {params}") + + if method == "GET": + return self._client.request( + method=method, + url=url, + headers=headers, + timeout=timeout, + params=params, + ) - return ( - 
requests_retry_session(requests.Session(), retries=self._num_retries) - if self._apply_retry_policy - else requests.Session() + if files is not None: + if method != "POST": + raise ValueError( + f"execute_http_request supports only POST for files upload, but {method} was supplied instead" + ) + if "Content-Type" in headers: + headers.pop( + "Content-Type" + ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload + data = params + else: + data = json.dumps(params).encode() if params else None + + return self._client.request( + method=method, + url=url, + headers=headers, + data=data, + timeout=timeout, + files=files, ) + def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client: + if client is not None: + return client + + return _requests_retry_session(retries=self._num_retries) if self._apply_retry_policy else httpx.Client() + def add_headers(self, headers: Dict[str, Any]) -> None: self._headers.update(headers) diff --git a/poetry.lock b/poetry.lock index 53368e37..835c27ae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "ai21-tokenizer" @@ -29,6 +29,28 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} +[[package]] +name = "anyio" +version = "4.3.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + [[package]] name = "authlib" version = "1.3.0" @@ -527,6 +549,62 @@ gitdb = ">=4.0.1,<5" doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"] test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = 
"sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + +[[package]] +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "huggingface-hub" version = "0.23.0" @@ -1510,6 +1588,32 @@ files = [ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + 
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + +[[package]] +name = "tenacity" +version = "8.3.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"}, + {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tokenizers" version = "0.19.1" @@ -1766,4 +1870,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "57e341097c164ee614872b24b6dfbd102905b1dbde1cbde5e521b5a7844d15a5" +content-hash = "c6c474d713d3660255aade619131c42d2c57ba74d992b7e20be02275601a5b48" diff --git a/pyproject.toml b/pyproject.toml index f24372c9..d8a88b0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,11 +56,12 @@ packages = [ [tool.poetry.dependencies] python = "^3.8" -requests = "^2.31.0" ai21-tokenizer = ">=0.9.1,<1.0.0" boto3 = { version = "^1.28.82", optional = true } dataclasses-json = "^0.6.3" typing-extensions = "^4.9.0" +httpx = "^0.27.0" +tenacity = "^8.3.0" [tool.poetry.group.dev.dependencies] diff --git a/setup.py b/setup.py index 2ba17979..e65e5028 100755 --- a/setup.py +++ b/setup.py @@ -21,6 +21,6 @@ packages=find_packages(exclude=["tests", "tests.*"]), keywords=["python", "sdk", "ai", "ai21", "jurassic", "ai21-python", "llm"], install_requires=[ - "requests", + "httpx", ], ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 02e4d467..68c4efb5 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -1,5 +1,5 @@ import pytest -import requests +import httpx @pytest.fixture @@ -8,5 +8,5 @@ def dummy_api_host() -> str: @pytest.fixture -def 
mock_requests_session(mocker) -> requests.Session: - return mocker.Mock(spec=requests.Session) +def mock_httpx_client(mocker) -> httpx.Client: + return mocker.Mock(spec=httpx.Client) diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 86cde2ee..1d730f1a 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -1,8 +1,8 @@ import platform from typing import Optional +import httpx import pytest -import requests from ai21.ai21_http_client import AI21HTTPClient from ai21.http_client import HttpClient @@ -94,12 +94,12 @@ def test__execute_http_request__( params, headers, dummy_api_host: str, - mock_requests_session: requests.Session, + mock_httpx_client: httpx.Client, ): response_json = {"test_key": "test_value"} - mock_requests_session.request.return_value = MockResponse(response_json, 200) + mock_httpx_client.request.return_value = MockResponse(response_json, 200) - http_client = HttpClient(session=mock_requests_session) + http_client = HttpClient(client=mock_httpx_client) client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") response = client.execute_http_request(**params) @@ -107,7 +107,7 @@ def test__execute_http_request__( if "files" in params: # We split it because when calling requests with "files", "params" is turned into "data" - mock_requests_session.request.assert_called_once_with( + mock_httpx_client.request.assert_called_once_with( timeout=300, headers=headers, files=params["files"], @@ -116,18 +116,18 @@ def test__execute_http_request__( method=params["method"], ) else: - mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + mock_httpx_client.request.assert_called_once_with(timeout=300, headers=headers, **params) def test__execute_http_request__when_files_with_put_method__should_raise_value_error( dummy_api_host: str, - mock_requests_session: 
requests.Session, + mock_httpx_client: httpx.Client, ): response_json = {"test_key": "test_value"} - http_client = HttpClient(session=mock_requests_session) + http_client = HttpClient(client=mock_httpx_client) client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") - mock_requests_session.request.return_value = MockResponse(response_json, 200) + mock_httpx_client.request.return_value = MockResponse(response_json, 200) with pytest.raises(ValueError): params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} client.execute_http_request(**params) diff --git a/tests/unittests/test_http_client.py b/tests/unittests/test_http_client.py new file mode 100644 index 00000000..73347f09 --- /dev/null +++ b/tests/unittests/test_http_client.py @@ -0,0 +1,39 @@ +import pytest +from unittest.mock import Mock +from urllib.request import Request + +import httpx + +from ai21.errors import ServiceUnavailable +from ai21.http_client import HttpClient, RETRY_ERROR_CODES + +_METHOD = "GET" +_URL = "http://test_url" + + +def test__execute_http_request__when_retry_error_code_once__should_retry_and_succeed(mock_httpx_client: Mock) -> None: + request = Request(method=_METHOD, url=_URL) + retries = 3 + mock_httpx_client.request.side_effect = [ + httpx.Response(status_code=429, request=request), + httpx.Response(status_code=200, request=request, json={"test_key": "test_value"}), + ] + + client = HttpClient(client=mock_httpx_client, num_retries=retries) + client.execute_http_request(method=_METHOD, url=_URL) + assert mock_httpx_client.request.call_count == retries - 1 + + +def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_httpx_client: Mock) -> None: + request = Request(method=_METHOD, url=_URL) + retries = len(RETRY_ERROR_CODES) + + mock_httpx_client.request.side_effect = [ + httpx.Response(status_code=status_code, request=request) for status_code in 
RETRY_ERROR_CODES + ] + + client = HttpClient(client=mock_httpx_client, num_retries=retries) + with pytest.raises(ServiceUnavailable): + client.execute_http_request(method=_METHOD, url=_URL) + + assert mock_httpx_client.request.call_count == retries From 91e78edf8919262ba1ed409127d52ca04ad2b3cb Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Thu, 16 May 2024 16:38:10 +0300 Subject: [PATCH 07/14] feat: Stream support jamba (#114) * feat: Added httpx to pyproject * feat: httpx instead of requests * feat: Removed requests * fix: not given * fix: setup * feat: Added tenacity for retry * fix: conftest * test: Added tests * fix: Rename * fix: Modified test * feat: stream support (unfinished) * feat: Added tenacity for retry * test: Added tests * fix: Rename * feat: stream support (unfinished) * fix: single request creation * feat: Support stream_cls * fix: passed response_cls * fix: Removed unnecessary json_to_response * fix: imports * fix: tests * fix: imports * fix: reponse parse * fix: Added two examples to tests * fix: sagemaker tests * test: Added stream tests * fix: comment out failing test * fix: CR * fix: Removed code * docs: Added readme for streaming * fix: condition * docs: readme --- README.md | 24 ++ ai21/ai21_http_client.py | 7 +- ai21/clients/common/chat_base.py | 3 - ai21/clients/common/completion_base.py | 5 +- ai21/clients/common/dataset_base.py | 5 - ai21/clients/common/segmentation_base.py | 3 - .../common/summarize_by_segment_base.py | 3 - .../studio/resources/chat/chat_completions.py | 52 +++- .../clients/studio/resources/studio_answer.py | 4 +- ai21/clients/studio/resources/studio_chat.py | 3 +- .../studio/resources/studio_completion.py | 2 +- .../studio/resources/studio_custom_model.py | 8 +- .../studio/resources/studio_dataset.py | 7 +- ai21/clients/studio/resources/studio_embed.py | 3 +- ai21/clients/studio/resources/studio_gec.py | 3 +- .../studio/resources/studio_improvements.py | 3 +- 
.../studio/resources/studio_library.py | 17 +- .../studio/resources/studio_paraphrase.py | 3 +- .../studio/resources/studio_resource.py | 65 ++++- .../studio/resources/studio_segmentation.py | 3 +- .../studio/resources/studio_summarize.py | 3 +- .../resources/studio_summarize_by_segment.py | 4 +- ai21/errors.py | 6 + ai21/http_client.py | 30 ++- ai21/models/chat/chat_completion_chunk.py | 27 ++ ai21/services/sagemaker.py | 4 +- ai21/stream.py | 67 +++++ ai21/types.py | 12 +- ai21/utils/typing.py | 12 +- .../studio/chat/stream_chat_completions.py | 21 ++ .../integration_tests/clients/test_studio.py | 4 + .../clients/studio/resources/conftest.py | 248 ++++++++++-------- .../studio/resources/test_studio_resources.py | 30 ++- tests/unittests/conftest.py | 5 + tests/unittests/services/test_sagemaker.py | 18 +- tests/unittests/test_ai21_http_client.py | 22 +- tests/unittests/test_http_client.py | 8 +- tests/unittests/test_stream.py | 64 +++++ 38 files changed, 588 insertions(+), 220 deletions(-) create mode 100644 ai21/models/chat/chat_completion_chunk.py create mode 100644 ai21/stream.py create mode 100644 examples/studio/chat/stream_chat_completions.py create mode 100644 tests/unittests/test_stream.py diff --git a/README.md b/README.md index 98e1c244..019d79bd 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,30 @@ For a more detailed example, see the completion [examples](examples/studio/compl --- +## Streaming + +We currently support streaming for the Chat Completions API in Jamba. + +```python +from ai21 import AI21Client +from ai21.models.chat import ChatMessage + +messages = [ChatMessage(content="What is the meaning of life?", role="user")] + +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="jamba-instruct-preview", + stream=True, +) +for chunk in response: + print(chunk.choices[0].delta.content, end="") + +``` + +--- + ## TSMs AI21 Studio's Task-Specific Models offer a range of powerful tools. 
These models have been specifically designed for their respective tasks and provide high-quality results while optimizing efficiency. diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 61c80a93..1a3b334f 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,8 @@ import platform from typing import Optional, Dict, Any, BinaryIO +import httpx + from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -76,9 +78,10 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, + stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, - ): - return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) + ) -> httpx.Response: + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files, stream=stream) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index c89a4e12..4ade9690 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -53,9 +53,6 @@ def create( def completions(self) -> ChatCompletions: pass - def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: - return ChatResponse.from_dict(json) - def _create_body( self, model: str, diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index b16c3990..49e5b5ea 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Dict, Any +from typing import List, Dict from ai21.models import Penalty, CompletionsResponse from ai21.types import NOT_GIVEN, NotGiven @@ -55,9 +55,6 @@ def create( """ pass - def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse: - 
return CompletionsResponse.from_dict(json) - def _create_body( self, model: str, diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py index dd5982ab..97bae64a 100644 --- a/ai21/clients/common/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.models import DatasetResponse - class Dataset(ABC): _module_name = "dataset" @@ -40,9 +38,6 @@ def list(self): def get(self, dataset_pid: str): pass - def _json_to_response(self, json: Dict[str, Any]) -> DatasetResponse: - return DatasetResponse.from_dict(json) - def _create_body( self, dataset_name: str, diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py index 6805e70f..074ba8c8 100644 --- a/ai21/clients/common/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -18,8 +18,5 @@ def create(self, source: str, source_type: DocumentType, **kwargs) -> Segmentati """ pass - def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: - return SegmentationResponse.from_dict(json) - def _create_body(self, source: str, source_type: str, **kwargs) -> Dict[str, Any]: return {"source": source, "sourceType": source_type, **kwargs} diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py index e5163ca9..516d0ebe 100644 --- a/ai21/clients/common/summarize_by_segment_base.py +++ b/ai21/clients/common/summarize_by_segment_base.py @@ -26,9 +26,6 @@ def create( """ pass - def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse: - return SummarizeBySegmentResponse.from_dict(json) - def _create_body( self, source: str, diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 3516f04e..65cc2d40 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ 
b/ai21/clients/studio/resources/chat/chat_completions.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import List, Optional, Union, Any, Dict +from typing import List, Optional, Union, Any, Dict, Literal, overload from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.chat import ChatMessage, ChatCompletionResponse from ai21.models import ChatMessage as J2ChatMessage +from ai21.models.chat import ChatMessage, ChatCompletionResponse +from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk +from ai21.stream import Stream from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given @@ -14,6 +16,7 @@ class ChatCompletions(StudioResource): _module_name = "chat/completions" + @overload def create( self, model: str, @@ -23,8 +26,38 @@ def create( top_p: float | NotGiven = NOT_GIVEN, stop: str | List[str] | NotGiven = NOT_GIVEN, n: int | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> ChatCompletionResponse: + pass + + @overload + def create( + self, + model: str, + messages: List[ChatMessage], + stream: Literal[True], + max_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + stop: str | List[str] | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> Stream[ChatCompletionChunk]: + pass + + def create( + self, + model: str, + messages: List[ChatMessage], + max_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + stop: str | List[str] | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> ChatCompletionResponse | Stream[ChatCompletionChunk]: if any(isinstance(item, J2ChatMessage) for item in messages): raise ValueError( "Please use the ChatMessage class from 
ai21.models.chat" @@ -39,12 +72,18 @@ def create( max_tokens=max_tokens, top_p=top_p, n=n, + stream=stream or False, **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post( + url=url, + body=body, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], + response_cls=ChatCompletionResponse, + ) def _create_body( self, @@ -55,6 +94,7 @@ def _create_body( top_p: Optional[float] | NotGiven, stop: Optional[Union[str, List[str]]] | NotGiven, n: Optional[int] | NotGiven, + stream: Literal[False] | Literal[True] | NotGiven, **kwargs: Any, ) -> Dict[str, Any]: return remove_not_given( @@ -66,9 +106,7 @@ def _create_body( "topP": top_p, "stop": stop, "n": n, + "stream": stream, **kwargs, } ) - - def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse: - return ChatCompletionResponse.from_dict(json) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 0831e37e..0da4ae01 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -14,6 +14,4 @@ def create( body = self._create_body(context=context, question=question, **kwargs) - response = self._post(url=url, body=body) - - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=AnswerResponse) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index ab6b9cda..daccea1d 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -49,8 +49,7 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{model}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=ChatResponse) @property def completions(self) -> 
ChatCompletions: diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 8594105b..d2b9cdc7 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -53,4 +53,4 @@ def create( logit_bias=logit_bias, **kwargs, ) - return self._json_to_response(self._post(url=url, body=body)) + return self._post(url=url, body=body, response_cls=CompletionsResponse) diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py index 1e1fb9b4..61e351d1 100644 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ b/ai21/clients/studio/resources/studio_custom_model.py @@ -25,14 +25,12 @@ def create( num_epochs=num_epochs, **kwargs, ) - self._post(url=url, body=body) + self._post(url=url, body=body, response_cls=None) def list(self) -> List[CustomBaseModelResponse]: url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._get(url=url) - - return [self._json_to_response(r) for r in response] + return self._get(url=url, response_cls=List[CustomBaseModelResponse]) def get(self, resource_id: str) -> CustomBaseModelResponse: url = f"{self._client.get_base_url()}/{self._module_name}/{resource_id}" - return self._json_to_response(self._get(url=url)) + return self._get(url=url, response_cls=CustomBaseModelResponse) diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 0620fa00..10f26ddd 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -33,14 +33,11 @@ def create( ) def list(self) -> List[DatasetResponse]: - response = self._get(url=self._base_url()) - return [self._json_to_response(r) for r in response] + return self._get(url=self._base_url(), response_cls=List[DatasetResponse]) def get(self, dataset_pid: str) -> DatasetResponse: url = 
f"{self._base_url()}/{dataset_pid}" - response = self._get(url=url) - - return self._json_to_response(response) + return self._get(url=url, response_cls=DatasetResponse) def _base_url(self) -> str: return f"{self._client.get_base_url()}/{self._module_name}" diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index c6af637b..e45b6269 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -9,6 +9,5 @@ class StudioEmbed(StudioResource, Embed): def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type, **kwargs) - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=EmbedResponse) diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py index e5b2b2ff..a8752c9c 100644 --- a/ai21/clients/studio/resources/studio_gec.py +++ b/ai21/clients/studio/resources/studio_gec.py @@ -7,6 +7,5 @@ class StudioGEC(StudioResource, GEC): def create(self, text: str, **kwargs) -> GECResponse: body = self._create_body(text=text, **kwargs) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=GECResponse) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index f767684c..88ea996c 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -13,6 +13,5 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, 
types=types, **kwargs) - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=ImprovementsResponse) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 8b6e6622..e177df94 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Optional, List + from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse @@ -34,15 +35,14 @@ def create( files = {"file": open(file_path, "rb")} body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs}) - raw_response = self._post(url=url, files=files, body=body) + raw_response = self._post(url=url, files=files, body=body, response_cls=dict) return raw_response["fileId"] def get(self, file_id: str) -> FileResponse: url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" - raw_response = self._get(url=url) - return FileResponse.from_dict(raw_response) + return self._get(url=url, response_cls=FileResponse) def list( self, @@ -53,9 +53,8 @@ def list( ) -> List[FileResponse]: url = f"{self._client.get_base_url()}/{self._module_name}" params = remove_not_given({"offset": offset, "limit": limit}) - raw_response = self._get(url=url, params=params) - return [FileResponse.from_dict(file) for file in raw_response] + return self._get(url=url, params=params, response_cls=List[FileResponse]) def update( self, @@ -102,8 +101,8 @@ def create( **kwargs, } ) - raw_response = self._post(url=url, body=body) - return LibrarySearchResponse.from_dict(raw_response) + + return self._post(url=url, body=body, response_cls=LibrarySearchResponse) class LibraryAnswer(StudioResource): @@ -128,5 +127,5 @@ 
def create( **kwargs, } ) - raw_response = self._post(url=url, body=body) - return LibraryAnswerResponse.from_dict(raw_response) + + return self._post(url=url, body=body, response_cls=LibraryAnswerResponse) diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index 5def5608..25764e46 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -23,6 +23,5 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=ParaphraseResponse) diff --git a/ai21/clients/studio/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py index 8ece396e..8f4be991 100644 --- a/ai21/clients/studio/resources/studio_resource.py +++ b/ai21/clients/studio/resources/studio_resource.py @@ -1,10 +1,16 @@ from __future__ import annotations +import json from abc import ABC -from typing import Any, Dict, Optional, BinaryIO +from typing import Any, Dict, Optional, BinaryIO, get_origin + +import httpx from ai21.ai21_http_client import AI21HTTPClient +from ai21.types import ResponseT, StreamT +from ai21.utils.typing import extract_type + class StudioResource(ABC): def __init__(self, client: AI21HTTPClient): @@ -14,23 +20,64 @@ def _post( self, url: str, body: Dict[str, Any], + response_cls: Optional[ResponseT] = None, + stream_cls: Optional[StreamT] = None, + stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, - ) -> Dict[str, Any]: - return self._client.execute_http_request( + ) -> ResponseT | StreamT: + response = self._client.execute_http_request( method="POST", url=url, + stream=stream, params=body or {}, files=files, ) - def _get(self, url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - return self._client.execute_http_request(method="GET", 
url=url, params=params or {}) + return self._cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls) + + def _get( + self, url: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None + ) -> ResponseT | StreamT: + response = self._client.execute_http_request(method="GET", url=url, params=params or {}) + return self._cast_response(response=response, response_cls=response_cls) - def _put(self, url: str, body: Dict[str, Any] = None) -> Dict[str, Any]: - return self._client.execute_http_request(method="PUT", url=url, params=body or {}) + def _put( + self, url: str, response_cls: Optional[ResponseT] = None, body: Dict[str, Any] = None + ) -> ResponseT | StreamT: + response = self._client.execute_http_request(method="PUT", url=url, params=body or {}) + return self._cast_response(response=response, response_cls=response_cls) - def _delete(self, url: str) -> Dict[str, Any]: - return self._client.execute_http_request( + def _delete(self, url: str, response_cls: Optional[ResponseT] = None) -> ResponseT | StreamT: + response = self._client.execute_http_request( method="DELETE", url=url, ) + return self._cast_response(response=response, response_cls=response_cls) + + def _cast_response( + self, + response: httpx.Response, + response_cls: Optional[ResponseT], + stream_cls: Optional[StreamT] = None, + stream: bool = False, + ) -> ResponseT | StreamT | None: + if stream and stream_cls is not None: + cast_to = extract_type(stream_cls) + return stream_cls(cast_to=cast_to, response=response) + + if response_cls is None: + return None + + if response_cls == dict: + return response.json() + + if response_cls == str: + return json.loads(response.json()) + + origin_type = get_origin(response_cls) + + if origin_type is not None and origin_type == list: + subtype = extract_type(response_cls) + return [subtype.from_dict(item) for item in response.json()] + + return response_cls.from_dict(response.json()) diff --git 
a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index e8e44efe..a2aee960 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -7,6 +7,5 @@ class StudioSegmentation(StudioResource, Segmentation): def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: body = self._create_body(source=source, source_type=source_type.value, **kwargs) url = f"{self._client.get_base_url()}/{self._module_name}" - raw_response = self._post(url=url, body=body) - return self._json_to_response(raw_response) + return self._post(url=url, body=body, response_cls=SegmentationResponse) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 7fd84756..6ba4f9fe 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -25,6 +25,5 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + return self._post(url=url, body=body, response_cls=SummarizeResponse) diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index 292dcbaf..abb1705e 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -16,5 +16,5 @@ def create( **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=body) - return self._json_to_response(response) + + return self._post(url=url, body=body, response_cls=SummarizeBySegmentResponse) diff --git a/ai21/errors.py b/ai21/errors.py index 4a0f8c92..7d5bae8d 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -87,3 +87,9 @@ class 
EmptyMandatoryListError(AI21Error): def __init__(self, key: str): message = f"Supplied {key} is empty. At least one element should be present in the list" super().__init__(message) + + +class StreamingDecodeError(AI21Error): + def __init__(self, chunk: str): + message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format" + super().__init__(message) diff --git a/ai21/http_client.py b/ai21/http_client.py index e199cb1b..7bb1343f 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -18,8 +18,8 @@ DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 -TIME_BETWEEN_RETRIES = 1 RETRY_BACK_OFF_FACTOR = 0.5 +TIME_BETWEEN_RETRIES = 1 RETRY_ERROR_CODES = (408, 429, 500, 503) RETRY_METHOD_WHITELIST = ["GET", "POST", "PUT"] @@ -76,10 +76,11 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, + stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, - ): + ) -> httpx.Response: try: - response = self._request(files=files, method=method, params=params, url=url) + response = self._request(files=files, method=method, params=params, url=url, stream=stream) except RetryError as retry_error: last_attempt = retry_error.last_attempt @@ -99,24 +100,27 @@ def execute_http_request( logger.error(f"Calling {method} {url} failed with a non-200 response code: {response.status_code}") handle_non_success_response(response.status_code, response.text) - return response.json() + return response def _request( - self, files: Optional[Dict[str, BinaryIO]], method: str, params: Optional[Dict], url: str + self, + files: Optional[Dict[str, BinaryIO]], + method: str, + params: Optional[Dict], + url: str, + stream: bool, ) -> httpx.Response: timeout = self._timeout_sec headers = self._headers logger.debug(f"Calling {method} {url} {headers} {params}") if method == "GET": - return self._client.request( - method=method, - url=url, - headers=headers, - timeout=timeout, - params=params, + request = self._client.build_request( + 
method=method, url=url, headers=headers, timeout=timeout, params=params ) + return self._client.send(request=request, stream=stream) + if files is not None: if method != "POST": raise ValueError( @@ -130,7 +134,7 @@ def _request( else: data = json.dumps(params).encode() if params else None - return self._client.request( + request = self._client.build_request( method=method, url=url, headers=headers, @@ -139,6 +143,8 @@ def _request( files=files, ) + return self._client.send(request=request, stream=stream) + def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client: if client is not None: return client diff --git a/ai21/models/chat/chat_completion_chunk.py b/ai21/models/chat/chat_completion_chunk.py new file mode 100644 index 00000000..2a63e63c --- /dev/null +++ b/ai21/models/chat/chat_completion_chunk.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Optional, List + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.models.logprobs import Logprobs +from ai21.models.usage_info import UsageInfo + + +@dataclass +class ChoiceDelta(AI21BaseModelMixin): + content: Optional[str] = None + role: Optional[str] = None + + +@dataclass +class ChoicesChunk(AI21BaseModelMixin): + index: int + delta: ChoiceDelta + logprobs: Optional[Logprobs] = None + finish_reason: Optional[str] = None + + +@dataclass +class ChatCompletionChunk(AI21BaseModelMixin): + id: str + choices: List[ChoicesChunk] + usage: Optional[UsageInfo] = None diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index 627fbe2e..3132cd1a 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -31,7 +31,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = response["arn"] + arn = response.json()["arn"] if not arn: raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) @@ -53,7 +53,7 @@ def list_model_package_versions(cls, model_name: str, 
region: str) -> List[str]: }, ) - return response["versions"] + return response.json()["versions"] @classmethod def _create_ai21_http_client(cls) -> AI21HTTPClient: diff --git a/ai21/stream.py b/ai21/stream.py new file mode 100644 index 00000000..bd324000 --- /dev/null +++ b/ai21/stream.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import json +from typing import TypeVar, Generic, Iterator, Optional + +import httpx + +from ai21.errors import StreamingDecodeError + +_T = TypeVar("_T") +_SSE_DATA_PREFIX = "data: " +_SSE_DONE_MSG = "[DONE]" + + +class Stream(Generic[_T]): + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[_T], + response: httpx.Response, + ): + self.response = response + self.cast_to = cast_to + self._decoder = _SSEDecoder() + self._iterator = self.__stream__() + + def __next__(self) -> _T: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[_T]: + for item in self._iterator: + yield item + + def __stream__(self) -> Iterator[_T]: + for chunk in self._decoder.iter(self.response.iter_lines()): + if chunk.endswith(_SSE_DONE_MSG): + break + + try: + chunk = json.loads(chunk) + if hasattr(self.cast_to, "from_dict"): + yield self.cast_to.from_dict(chunk) + else: + yield self.cast_to(**chunk) + except json.JSONDecodeError: + raise StreamingDecodeError(chunk) + + +class _SSEDecoder: + def iter(self, iterator: Iterator[str]): + for line in iterator: + line = line.strip() + decoded_line = self._decode(line) + + if decoded_line is not None: + yield decoded_line + + def _decode(self, line: str) -> Optional[str]: + if not line: + return None + + if line.startswith(_SSE_DATA_PREFIX): + return line.strip(_SSE_DATA_PREFIX) + + raise StreamingDecodeError(f"Invalid SSE line: {line}") diff --git a/ai21/types.py b/ai21/types.py index 137c938d..13a7a035 100644 --- a/ai21/types.py +++ b/ai21/types.py @@ -1,4 +1,14 @@ -from typing_extensions import Literal +from typing import Any, Union, List + +import httpx +from 
typing_extensions import Literal, TypeVar, TYPE_CHECKING +from ai21.stream import Stream + +if TYPE_CHECKING: + from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin # noqa + +ResponseT = TypeVar("_ResponseT", bound=Union["AI21BaseModelMixin", str, httpx.Response, List[Any]]) +StreamT = TypeVar("_StreamT", bound=Stream[Any]) # Sentinel class used until PEP 0661 is accepted diff --git a/ai21/utils/typing.py b/ai21/utils/typing.py index ae77329b..be0244c7 100644 --- a/ai21/utils/typing.py +++ b/ai21/utils/typing.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, get_args, cast from ai21.types import NotGiven @@ -20,3 +20,13 @@ def to_lower_camel_case(snake_str: str) -> str: # with the 'capitalize' method and join them together. camel_string = to_camel_case(snake_str) return snake_str[0].lower() + camel_string[1:] + + +def extract_type(type_to_extract: Any) -> type: + args = get_args(type_to_extract) + try: + return cast(type, args[0]) + except IndexError as err: + raise RuntimeError( + f"Expected type {type_to_extract} to have a type argument at index 0 but it did not" + ) from err diff --git a/examples/studio/chat/stream_chat_completions.py b/examples/studio/chat/stream_chat_completions.py new file mode 100644 index 00000000..fd079962 --- /dev/null +++ b/examples/studio/chat/stream_chat_completions.py @@ -0,0 +1,21 @@ +from ai21 import AI21Client +from ai21.models.chat import ChatMessage + +system = "You're a support engineer in a SaaS company" +messages = [ + ChatMessage(content=system, role="system"), + ChatMessage(content="Hello, I need help with a signup process.", role="user"), + ChatMessage(content="Hi Alice, I can help you with that. 
What seems to be the problem?", role="assistant"), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), +] + +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="jamba-instruct-preview", + max_tokens=100, + stream=True, +) +for chunk in response: + print(chunk.choices[0].delta.content, end="") diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index e5500d22..2d2d78f8 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -27,6 +27,8 @@ ("summarize.py",), ("summarize_by_segment.py",), ("tokenization.py",), + ("chat/chat_completions.py",), + # ("chat/stream_chat_completions.py",), # Uncomment when streaming is supported in production # ("custom_model.py", ), # ('custom_model_completion.py', ), # ("dataset.py", ), @@ -45,6 +47,8 @@ "when_summarize__should_return_ok", "when_summarize_by_segment__should_return_ok", "when_tokenization__should_return_ok", + "when_chat_completions__should_return_ok", + # "when_stream_chat_completions__should_return_ok", # "when_custom_model__should_return_ok", # "when_custom_model_completion__should_return_ok", # "when_dataset__should_return_ok", diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 132f95ad..627a2eee 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -1,4 +1,5 @@ import pytest +import httpx from pytest_mock import MockerFixture from ai21.ai21_http_client import AI21HTTPClient @@ -18,39 +19,24 @@ ChatMessage, RoleType, ChatResponse, - ChatOutput, - FinishReason, - Prompt, - Completion, - CompletionData, - CompletionFinishReason, CompletionsResponse, EmbedType, EmbedResponse, - EmbedResult, GECResponse, - Correction, - CorrectionType, ImprovementType, 
ImprovementsResponse, - Improvement, ParaphraseStyleType, ParaphraseResponse, - Suggestion, DocumentType, SegmentationResponse, SummaryMethod, SummarizeResponse, SummarizeBySegmentResponse, - SegmentSummary, ) from ai21.models.chat import ( ChatMessage as ChatCompletionChatMessage, ChatCompletionResponse, - ChatCompletionResponseChoice, ) -from ai21.models.responses.segmentation_response import Segment -from ai21.models.usage_info import UsageInfo from ai21.utils.typing import to_lower_camel_case @@ -59,9 +45,18 @@ def mock_ai21_studio_client(mocker: MockerFixture) -> AI21HTTPClient: return mocker.MagicMock(spec=AI21HTTPClient) +@pytest.fixture +def mock_successful_httpx_response(mocker: MockerFixture) -> httpx.Response: + mock_httpx_response = mocker.Mock(spec=httpx.Response) + mock_httpx_response.status_code = 200 + + return mock_httpx_response + + def get_studio_answer(): _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" _DUMMY_QUESTION = "What is the answer?" 
+ json_response = {"id": "some-id", "answer_in_context": True, "answer": "42"} return ( StudioAnswer, @@ -71,7 +66,8 @@ def get_studio_answer(): "context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION, }, - AnswerResponse(id="some-id", answer_in_context=True, answer="42"), + httpx.Response(status_code=200, json=json_response), + AnswerResponse.from_dict(json_response), ) @@ -85,6 +81,15 @@ def get_studio_chat(): ), ] _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" + json_response = { + "outputs": [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "finishReason": {"reason": "dummy_reason", "length": 1, "sequence": "1"}, + } + ] + } return ( StudioChat, @@ -105,15 +110,8 @@ def get_studio_chat(): "presencePenalty": None, "countPenalty": None, }, - ChatResponse( - outputs=[ - ChatOutput( - text="Hello, I need help with a signup process.", - role="user", - finish_reason=FinishReason(reason="dummy_reason", length=1, sequence="1"), - ) - ] - ), + httpx.Response(status_code=200, json=json_response), + ChatResponse.from_dict(json_response), ) @@ -127,6 +125,25 @@ def get_chat_completions(): ), ] _EXPECTED_SERIALIZED_MESSAGES = [message.to_dict() for message in _DUMMY_MESSAGES] + json_response = { + "id": "some-id", + "choices": [ + { + "index": 0, + "message": { + "content": "Hello, I need help with a signup process.", + "role": "user", + }, + "finishReason": "dummy_reason", + "logprobs": None, + } + ], + "usage": { + "promptTokens": 10, + "completionTokens": 20, + "totalTokens": 30, + }, + } return ( ChatCompletions, @@ -135,31 +152,26 @@ def get_chat_completions(): { "model": _DUMMY_MODEL, "messages": _EXPECTED_SERIALIZED_MESSAGES, + "stream": False, }, - ChatCompletionResponse( - id="some-id", - choices=[ - ChatCompletionResponseChoice( - index=0, - message=ChatCompletionChatMessage( - content="Hello, I need help with a signup process.", role=RoleType.USER - ), - finish_reason="dummy_reason", - logprobs=None, - ) - ], - 
usage=UsageInfo( - prompt_tokens=10, - completion_tokens=20, - total_tokens=30, - ), - ), + httpx.Response(status_code=200, json=json_response), + ChatCompletionResponse.from_dict(json_response), ) def get_studio_completion(**kwargs): _DUMMY_MODEL = "dummy-completion-model" _DUMMY_PROMPT = "dummy-prompt" + json_response = { + "id": "some-id", + "completions": [ + { + "data": {"text": "dummy-completion", "tokens": []}, + "finishReason": {"reason": "dummy_reason", "length": 1}, + } + ], + "prompt": {"text": "dummy-prompt"}, + } return ( StudioCompletion, @@ -170,20 +182,20 @@ def get_studio_completion(**kwargs): "prompt": _DUMMY_PROMPT, **{to_lower_camel_case(k): v for k, v in kwargs.items()}, }, - CompletionsResponse( - id="some-id", - completions=[ - Completion( - data=CompletionData(text="dummy-completion", tokens=[]), - finish_reason=CompletionFinishReason(reason="dummy_reason", length=1), - ) - ], - prompt=Prompt(text="dummy-prompt"), - ), + httpx.Response(status_code=200, json=json_response), + CompletionsResponse.from_dict(json_response), ) def get_studio_embed(): + json_response = { + "id": "some-id", + "results": [ + {"embedding": [1.0, 2.0, 3.0]}, + {"embedding": [4.0, 5.0, 6.0]}, + ], + } + return ( StudioEmbed, {"texts": ["text0", "text1"], "type": EmbedType.QUERY}, @@ -192,18 +204,26 @@ def get_studio_embed(): "texts": ["text0", "text1"], "type": EmbedType.QUERY.value, }, - EmbedResponse( - id="some-id", - results=[ - EmbedResult([1.0, 2.0, 3.0]), - EmbedResult([4.0, 5.0, 6.0]), - ], - ), + httpx.Response(status_code=200, json=json_response), + EmbedResponse.from_dict(json_response), ) def get_studio_gec(): + json_response = { + "id": "some-id", + "corrections": [ + { + "suggestion": "text to fix", + "startIndex": 9, + "endIndex": 10, + "originalText": "text to fi", + "correctionType": "Spelling", + } + ], + } text = "text to fi" + return ( StudioGEC, {"text": text}, @@ -211,24 +231,27 @@ def get_studio_gec(): { "text": text, }, - GECResponse( - 
id="some-id", - corrections=[ - Correction( - suggestion="text to fix", - start_index=9, - end_index=10, - original_text=text, - correction_type=CorrectionType.SPELLING, - ) - ], - ), + httpx.Response(status_code=200, json=json_response), + GECResponse.from_dict(json_response), ) def get_studio_improvements(): + json_response = { + "id": "some-id", + "improvements": [ + { + "suggestions": ["This text is improved"], + "startIndex": 0, + "endIndex": 15, + "originalText": "text to improve", + "improvementType": "FLUENCY", + } + ], + } text = "text to improve" types = [ImprovementType.FLUENCY] + return ( StudioImprovements, {"text": text, "types": types}, @@ -237,18 +260,8 @@ def get_studio_improvements(): "text": text, "types": types, }, - ImprovementsResponse( - id="some-id", - improvements=[ - Improvement( - suggestions=["This text is improved"], - start_index=0, - end_index=15, - original_text=text, - improvement_type=ImprovementType.FLUENCY, - ) - ], - ), + httpx.Response(status_code=200, json=json_response), + ImprovementsResponse.from_dict(json_response), ) @@ -257,6 +270,15 @@ def get_studio_paraphrase(): style = ParaphraseStyleType.CASUAL start_index = 0 end_index = 10 + json_response = { + "id": "some-id", + "suggestions": [ + { + "text": "This text is paraphrased", + } + ], + } + return ( StudioParaphrase, {"text": text, "style": style, "start_index": start_index, "end_index": end_index}, @@ -267,13 +289,24 @@ def get_studio_paraphrase(): "startIndex": start_index, "endIndex": end_index, }, - ParaphraseResponse(id="some-id", suggestions=[Suggestion(text="This text is paraphrased")]), + httpx.Response(status_code=200, json=json_response), + ParaphraseResponse.from_dict(json_response), ) def get_studio_segmentation(): source = "segmentation text" source_type = DocumentType.TEXT + json_response = { + "id": "some-id", + "segments": [ + { + "segmentText": "This text is segmented", + "segmentType": "segment_type", + } + ], + } + return ( StudioSegmentation, 
{"source": source, "source_type": source_type}, @@ -282,9 +315,8 @@ def get_studio_segmentation(): "source": source, "sourceType": source_type, }, - SegmentationResponse( - id="some-id", segments=[Segment(segment_text="This text is segmented", segment_type="segment_type")] - ), + httpx.Response(status_code=200, json=json_response), + SegmentationResponse.from_dict(json_response), ) @@ -293,6 +325,11 @@ def get_studio_summarization(): source_type = "TEXT" focus = "text" summary_method = SummaryMethod.FULL_DOCUMENT + json_response = { + "id": "some-id", + "summary": "This text is summarized", + } + return ( StudioSummarize, {"source": source, "source_type": source_type, "focus": focus, "summary_method": summary_method}, @@ -303,10 +340,8 @@ def get_studio_summarization(): "focus": focus, "summaryMethod": summary_method, }, - SummarizeResponse( - id="some-id", - summary="This text is summarized", - ), + httpx.Response(status_code=200, json=json_response), + SummarizeResponse.from_dict(json_response), ) @@ -314,6 +349,20 @@ def get_studio_summarize_by_segment(): source = "text to summarize" source_type = "TEXT" focus = "text" + json_response = { + "id": "some-id", + "segments": [ + { + "summary": "This text is summarized", + "segmentText": "This text is segmented", + "segmentHtml": "", + "segmentType": "segment_type", + "hasSummary": True, + "highlights": [], + } + ], + } + return ( StudioSummarizeBySegment, {"source": source, "source_type": source_type, "focus": focus}, @@ -323,17 +372,6 @@ def get_studio_summarize_by_segment(): "sourceType": source_type, "focus": focus, }, - SummarizeBySegmentResponse( - id="some-id", - segments=[ - SegmentSummary( - summary="This text is summarized", - segment_text="This text is segmented", - segment_type="segment_type", - segment_html=None, - has_summary=True, - highlights=[], - ) - ], - ), + httpx.Response(status_code=200, json=json_response), + SummarizeBySegmentResponse.from_dict(json_response), ) diff --git 
a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 33c02318..f5ad3196 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -1,7 +1,7 @@ from typing import TypeVar, Callable import pytest - +import httpx from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_resource import StudioResource @@ -43,7 +43,14 @@ class TestStudioResources: "studio_summarization", "studio_summarize_by_segment", ], - argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], + argnames=[ + "studio_resource", + "function_body", + "url_suffix", + "expected_body", + "expected_httpx_response", + "expected_response", + ], argvalues=[ (get_studio_answer()), (get_studio_chat()), @@ -59,16 +66,17 @@ class TestStudioResources: (get_studio_summarize_by_segment()), ], ) - def test__create__should_return_answer_response( + def test__create__should_return_response( self, studio_resource: Callable[[AI21HTTPClient], T], function_body, url_suffix: str, expected_body, - expected_response, + expected_httpx_response, + expected_response: AnswerResponse, mock_ai21_studio_client: AI21HTTPClient, ): - mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() + mock_ai21_studio_client.execute_http_request.return_value = expected_httpx_response mock_ai21_studio_client.get_base_url.return_value = _BASE_URL resource = studio_resource(mock_ai21_studio_client) @@ -82,12 +90,19 @@ def test__create__should_return_answer_response( method="POST", url=f"{_BASE_URL}/{url_suffix}", params=expected_body, + stream=False, files=None, ) - def test__create__when_pass_kwargs__should_pass_to_request(self, mock_ai21_studio_client: AI21HTTPClient): + def 
test__create__when_pass_kwargs__should_pass_to_request( + self, + mock_ai21_studio_client: AI21HTTPClient, + mock_successful_httpx_response: httpx.Response, + ): expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") - mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() + mock_successful_httpx_response.json.return_value = expected_answer.to_dict() + + mock_ai21_studio_client.execute_http_request.return_value = mock_successful_httpx_response mock_ai21_studio_client.get_base_url.return_value = _BASE_URL studio_answer = StudioAnswer(mock_ai21_studio_client) @@ -105,5 +120,6 @@ def test__create__when_pass_kwargs__should_pass_to_request(self, mock_ai21_studi "question": _DUMMY_QUESTION, "some_dummy_kwargs": "some_dummy_value", }, + stream=False, files=None, ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 68c4efb5..f8efe8ae 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -10,3 +10,8 @@ def dummy_api_host() -> str: @pytest.fixture def mock_httpx_client(mocker) -> httpx.Client: return mocker.Mock(spec=httpx.Client) + + +@pytest.fixture +def mock_httpx_response(mocker) -> httpx.Response: + return mocker.Mock(spec=httpx.Response) diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index c6f9e165..1e8fdc04 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,3 +1,4 @@ +import httpx import pytest from ai21 import ModelPackageDoesntExistError @@ -8,28 +9,29 @@ class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self): - expected_response = { + def test__get_model_package_arn__should_return_model_package_arn(self, mock_httpx_response: httpx.Response): + mock_httpx_response.json.return_value = { "arn": _DUMMY_ARN, "versions": _DUMMY_VERSIONS, } - SageMakerStub.ai21_http_client.execute_http_request.return_value 
= expected_response + SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") assert actual_model_package_arn == _DUMMY_ARN - def test__get_model_package_arn__when_no_arn__should_raise_error(self): - SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + def test__get_model_package_arn__when_no_arn__should_raise_error(self, mock_httpx_response: httpx.Response): + mock_httpx_response.json.return_value = {"arn": []} + SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") - def test__list_model_package_versions__should_return_model_package_arn(self): - expected_response = { + def test__list_model_package_versions__should_return_model_package_arn(self, mock_httpx_response: httpx.Response): + mock_httpx_response.json.return_value = { "versions": _DUMMY_VERSIONS, } - SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 1d730f1a..3ae8f6c6 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -1,5 +1,7 @@ import platform from typing import Optional +from unittest.mock import Mock +from urllib.request import Request import httpx import pytest @@ -85,7 +87,13 @@ def test__get_base_url(api_host: Optional[str], expected_api_host: str): argvalues=[ ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), ( - {"method": "POST", "url": "test_url", "params": 
{"foo": "bar"}, "files": {"file": "test_file"}}, + { + "method": "POST", + "url": "test_url", + "params": {"foo": "bar"}, + "stream": False, + "files": {"file": "test_file"}, + }, _EXPECTED_POST_FILE_HEADERS, ), ], @@ -97,17 +105,19 @@ def test__execute_http_request__( mock_httpx_client: httpx.Client, ): response_json = {"test_key": "test_value"} - mock_httpx_client.request.return_value = MockResponse(response_json, 200) + mock_response = Mock(spec=Request) + mock_httpx_client.build_request.return_value = mock_response + mock_httpx_client.send.return_value = MockResponse(response_json, 200) http_client = HttpClient(client=mock_httpx_client) client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") response = client.execute_http_request(**params) - assert response == response_json + assert response.json() == response_json if "files" in params: # We split it because when calling requests with "files", "params" is turned into "data" - mock_httpx_client.request.assert_called_once_with( + mock_httpx_client.build_request.assert_called_once_with( timeout=300, headers=headers, files=params["files"], @@ -116,7 +126,9 @@ def test__execute_http_request__( method=params["method"], ) else: - mock_httpx_client.request.assert_called_once_with(timeout=300, headers=headers, **params) + mock_httpx_client.build_request.assert_called_once_with(timeout=300, headers=headers, **params) + + mock_httpx_client.send.assert_called_once_with(request=mock_response, stream=False) def test__execute_http_request__when_files_with_put_method__should_raise_value_error( diff --git a/tests/unittests/test_http_client.py b/tests/unittests/test_http_client.py index 73347f09..0342c426 100644 --- a/tests/unittests/test_http_client.py +++ b/tests/unittests/test_http_client.py @@ -14,21 +14,21 @@ def test__execute_http_request__when_retry_error_code_once__should_retry_and_succeed(mock_httpx_client: Mock) -> None: request = Request(method=_METHOD, 
url=_URL) retries = 3 - mock_httpx_client.request.side_effect = [ + mock_httpx_client.send.side_effect = [ httpx.Response(status_code=429, request=request), httpx.Response(status_code=200, request=request, json={"test_key": "test_value"}), ] client = HttpClient(client=mock_httpx_client, num_retries=retries) client.execute_http_request(method=_METHOD, url=_URL) - assert mock_httpx_client.request.call_count == retries - 1 + assert mock_httpx_client.send.call_count == retries - 1 def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_httpx_client: Mock) -> None: request = Request(method=_METHOD, url=_URL) retries = len(RETRY_ERROR_CODES) - mock_httpx_client.request.side_effect = [ + mock_httpx_client.send.side_effect = [ httpx.Response(status_code=status_code, request=request) for status_code in RETRY_ERROR_CODES ] @@ -36,4 +36,4 @@ def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_htt with pytest.raises(ServiceUnavailable): client.execute_http_request(method=_METHOD, url=_URL) - assert mock_httpx_client.request.call_count == retries + assert mock_httpx_client.send.call_count == retries diff --git a/tests/unittests/test_stream.py b/tests/unittests/test_stream.py new file mode 100644 index 00000000..7abd2563 --- /dev/null +++ b/tests/unittests/test_stream.py @@ -0,0 +1,64 @@ +import json +from dataclasses import dataclass +from typing import AsyncIterable + +import httpx +import pytest + +from ai21.errors import StreamingDecodeError +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.stream import Stream + + +@dataclass +class StubStreamObject(AI21BaseModelMixin): + id: str + name: str + + +def async_byte_stream() -> AsyncIterable[bytes]: + for i in range(10): + data = {"id": f"some-{i}", "name": f"some-name-{i}"} + msg = f"data: {json.dumps(data)}\r\n" + yield msg.encode("utf-8") + + +def async_byte_bad_stream_prefix() -> AsyncIterable[bytes]: + msg = "bad_stream: {}\r\n" + yield 
msg.encode("utf-8") + + +def async_byte_bad_stream_json_format() -> AsyncIterable[bytes]: + msg = "data: not a json format\r\n" + yield msg.encode("utf-8") + + +def test_stream_object_when_json_string_ok__should_be_ok(): + stream = async_byte_stream() + response = httpx.Response(status_code=200, content=stream) + stream_obj = Stream[StubStreamObject](response=response, cast_to=StubStreamObject) + + for i, chunk in enumerate(stream_obj): + assert isinstance(chunk, StubStreamObject) + assert chunk.name == f"some-name-{i}" + assert chunk.id == f"some-{i}" + + +@pytest.mark.parametrize( + ids=[ + "bad_stream_data_prefix", + "bad_stream_json_format", + ], + argnames=["stream"], + argvalues=[ + (async_byte_bad_stream_prefix(),), + (async_byte_bad_stream_json_format(),), + ], +) +def test_stream_object_when_bad_json__should_raise_error(stream): + response = httpx.Response(status_code=200, content=stream) + stream_obj = Stream[StubStreamObject](response=response, cast_to=StubStreamObject) + + with pytest.raises(StreamingDecodeError): + for _ in stream_obj: + pass From 616cdef77b97262efc864846d3ba6262b4abe67f Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Sun, 19 May 2024 08:24:53 +0300 Subject: [PATCH 08/14] test: Added integration test for streaming --- .../studio/resources/chat/chat_completions.py | 3 +-- ai21/models/chat/__init__.py | 11 +++++++++- .../clients/studio/test_chat_completions.py | 21 +++++++++++++++++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 65cc2d40..cddc8cb3 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -4,8 +4,7 @@ from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import ChatMessage as J2ChatMessage -from ai21.models.chat import ChatMessage, ChatCompletionResponse -from 
ai21.models.chat.chat_completion_chunk import ChatCompletionChunk +from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk from ai21.stream import Stream from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py index 0fb4df66..d9332ff1 100644 --- a/ai21/models/chat/__init__.py +++ b/ai21/models/chat/__init__.py @@ -4,5 +4,14 @@ from .chat_completion_response import ChatCompletionResponseChoice from .chat_message import ChatMessage from .role_type import RoleType as RoleType +from .chat_completion_chunk import ChatCompletionChunk, ChoicesChunk, ChoiceDelta -__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice", "ChatMessage", "RoleType"] +__all__ = [ + "ChatCompletionResponse", + "ChatCompletionResponseChoice", + "ChatMessage", + "RoleType", + "ChatCompletionChunk", + "ChoicesChunk", + "ChoiceDelta", +] diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 684f0aa5..d564aeac 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -1,8 +1,7 @@ from ai21 import AI21Client -from ai21.models.chat import ChatMessage, ChatCompletionResponse +from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk, ChoicesChunk, ChoiceDelta from ai21.models import RoleType - _MODEL = "jamba-instruct-preview" _MESSAGES = [ ChatMessage( @@ -50,3 +49,21 @@ def test_chat_completion__with_n_param__should_return_n_choices(): for choice in response.choices: assert choice.message.content assert choice.message.role + + +def test_chat_completion__when_stream__should_return_chunks(): + messages = _MESSAGES + + client = AI21Client() + + response = client.chat.completions.create( + model=_MODEL, + messages=messages, + temperature=0, + 
stream=True, + ) + + for chunk in response: + assert isinstance(chunk, ChatCompletionChunk) + assert isinstance(chunk.choices[0], ChoicesChunk) + assert isinstance(chunk.choices[0].delta, ChoiceDelta) From 4ed66ccaa15ba759f49eaf7fd2770e175b7cd37b Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Sun, 19 May 2024 15:13:26 +0300 Subject: [PATCH 09/14] fix: Added enter and close to stream --- ai21/stream.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ai21/stream.py b/ai21/stream.py index bd324000..de3b01d3 100644 --- a/ai21/stream.py +++ b/ai21/stream.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +from types import TracebackType from typing import TypeVar, Generic, Iterator, Optional import httpx +from typing_extensions import Self from ai21.errors import StreamingDecodeError @@ -47,6 +49,20 @@ def __stream__(self) -> Iterator[_T]: except json.JSONDecodeError: raise StreamingDecodeError(chunk) + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self): + self.response.close() + class _SSEDecoder: def iter(self, iterator: Iterator[str]): From 9b6e8e58fc6f19559d9dbbc5d72bb6aa631239f6 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 20 May 2024 12:00:23 +0300 Subject: [PATCH 10/14] fix: Uncomment chat completions test --- tests/integration_tests/clients/test_studio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 2d2d78f8..42e7abf4 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -28,7 +28,7 @@ ("summarize_by_segment.py",), ("tokenization.py",), ("chat/chat_completions.py",), - # ("chat/stream_chat_completions.py",), # Uncomment when streaming is supported in production + 
("chat/stream_chat_completions.py",), # ("custom_model.py", ), # ('custom_model_completion.py', ), # ("dataset.py", ), From eddfa97d51dbf9f170ea7b22f521a7f5cd8c0204 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 20 May 2024 12:03:32 +0300 Subject: [PATCH 11/14] fix: poetry.lock --- poetry.lock | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 835c27ae..3ab79b3f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1215,6 +1215,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1222,8 +1223,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = 
"PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1240,6 +1248,7 @@ files = [ {file = 
"PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1247,6 +1256,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = 
"sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1333,24 +1343,24 @@ python-versions = ">=3.6" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d92f81886165cb14d7b067ef37e142256f1c6a90a65cd156b063a43da1708cfd"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux2014_aarch64.whl", hash = 
"sha256:b5edda50e5e9e15e54a6a8a0070302b00c518a9d32accc2346ad6c984aacd279"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:7048c338b6c86627afb27faecf418768acb6331fc24cfa56c93e8c9780f815fa"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, {file = 
"ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, @@ -1358,7 +1368,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3fcc54cb0c8b811ff66082de1680b4b14cf8a81dce0d4fbf665c2265a81e07a1"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, @@ -1366,7 +1376,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = 
"sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:665f58bfd29b167039f714c6998178d27ccd83984084c286110ef26b230f259f"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, @@ -1374,7 +1384,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:9eb5dee2772b0f704ca2e45b1713e4e5198c18f515b52743576d196348f374d3"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, {file = 
"ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, @@ -1870,4 +1880,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "c6c474d713d3660255aade619131c42d2c57ba74d992b7e20be02275601a5b48" +content-hash = "4f44e31b57439733446c5a31225f31cf3e3a5d4de8b8ee57de5ae62b99bbf076" From 4bad319b261bb472778beab135eb39dd014f2a95 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 20 May 2024 12:05:33 +0300 Subject: [PATCH 12/14] fix: poetry.lock --- poetry.lock | 78 ----------------------------------------------------- 1 file changed, 78 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9a8bd2bb..3ab79b3f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -65,28 +65,6 @@ files = [ [package.dependencies] cryptography = "*" -[[package]] -name = "anyio" -version = "4.3.0" -description = "High level compatibility layer for multiple asynchronous event loop implementations" -optional = false -python-versions = ">=3.8" -files = [ - {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, - {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, -] - -[package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} -idna = ">=2.8" -sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} - -[package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.23)"] - [[package]] name = "black" version = "24.4.2" @@ -627,62 +605,6 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 
(>=3,<5)"] socks = ["socksio (==1.*)"] -[[package]] -name = "h11" -version = "0.14.0" -description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -optional = false -python-versions = ">=3.7" -files = [ - {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, - {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, -] - -[[package]] -name = "httpcore" -version = "1.0.5" -description = "A minimal low-level HTTP client." -optional = false -python-versions = ">=3.8" -files = [ - {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, - {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, -] - -[package.dependencies] -certifi = "*" -h11 = ">=0.13,<0.15" - -[package.extras] -asyncio = ["anyio (>=4.0,<5.0)"] -http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.26.0)"] - -[[package]] -name = "httpx" -version = "0.27.0" -description = "The next generation HTTP client." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, -] - -[package.dependencies] -anyio = "*" -certifi = "*" -httpcore = "==1.*" -idna = "*" -sniffio = "*" - -[package.extras] -brotli = ["brotli", "brotlicffi"] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] -http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] - [[package]] name = "huggingface-hub" version = "0.23.0" From de11631745a2228c0f4f238d821a69197740a00b Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 20 May 2024 12:08:56 +0300 Subject: [PATCH 13/14] fix: Uncomment test case --- tests/integration_tests/clients/test_studio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 42e7abf4..6d0c948e 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -48,7 +48,7 @@ "when_summarize_by_segment__should_return_ok", "when_tokenization__should_return_ok", "when_chat_completions__should_return_ok", - # "when_stream_chat_completions__should_return_ok", + "when_stream_chat_completions__should_return_ok", # "when_custom_model__should_return_ok", # "when_custom_model_completion__should_return_ok", # "when_dataset__should_return_ok", From bfd03aa82c0ead9e5068b6034fbe9a7c3c9273a8 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 20 May 2024 12:16:24 +0300 Subject: [PATCH 14/14] fix: Removed unused json_to_response --- ai21/clients/common/custom_model_base.py | 3 --- ai21/clients/common/embed_base.py | 3 --- ai21/clients/common/improvements_base.py | 3 --- 3 files changed, 9 deletions(-) diff --git a/ai21/clients/common/custom_model_base.py 
b/ai21/clients/common/custom_model_base.py index 84980b28..776626d4 100644 --- a/ai21/clients/common/custom_model_base.py +++ b/ai21/clients/common/custom_model_base.py @@ -38,9 +38,6 @@ def list(self) -> List[CustomBaseModelResponse]: def get(self, resource_id: str) -> CustomBaseModelResponse: pass - def _json_to_response(self, json: Dict[str, Any]) -> CustomBaseModelResponse: - return CustomBaseModelResponse.from_dict(json) - def _create_body( self, dataset_id: str, diff --git a/ai21/clients/common/embed_base.py b/ai21/clients/common/embed_base.py index be4e70d4..6e16a795 100644 --- a/ai21/clients/common/embed_base.py +++ b/ai21/clients/common/embed_base.py @@ -19,8 +19,5 @@ def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs """ pass - def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse: - return EmbedResponse.from_dict(json) - def _create_body(self, texts: List[str], type: Optional[str], **kwargs) -> Dict[str, Any]: return {"texts": texts, "type": type, **kwargs} diff --git a/ai21/clients/common/improvements_base.py b/ai21/clients/common/improvements_base.py index 1dfafe0f..75ac7306 100644 --- a/ai21/clients/common/improvements_base.py +++ b/ai21/clients/common/improvements_base.py @@ -18,8 +18,5 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme """ pass - def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse: - return ImprovementsResponse.from_dict(json) - def _create_body(self, text: str, types: List[str], **kwargs) -> Dict[str, Any]: return {"text": text, "types": types, **kwargs}