From b0dc5a8fc84ab2a33c78a33630a80ff6591b102a Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Sun, 28 Jul 2024 14:31:43 +0300 Subject: [PATCH 01/11] feat: add stream support for bedrock --- README.md | 26 ++++ ai21/clients/bedrock/_stream_decoder.py | 68 +++++++++++ ai21/clients/bedrock/ai21_bedrock_client.py | 13 +- .../studio/resources/studio_resource.py | 19 ++- ai21/errors.py | 4 +- ai21/http_client/async_http_client.py | 3 + ai21/http_client/http_client.py | 3 + ai21/stream/async_stream.py | 7 +- ai21/stream/stream.py | 7 +- ai21/stream/stream_commons.py | 11 +- .../chat/async_stream_chat_completions.py | 25 ++++ .../bedrock/chat/stream_chat_completions.py | 22 ++++ .../integration_tests/clients/test_bedrock.py | 4 + tests/unittests/test_aws_stream_decoder.py | 114 ++++++++++++++++++ 14 files changed, 309 insertions(+), 17 deletions(-) create mode 100644 ai21/clients/bedrock/_stream_decoder.py create mode 100644 examples/bedrock/chat/async_stream_chat_completions.py create mode 100644 examples/bedrock/chat/stream_chat_completions.py create mode 100644 tests/unittests/test_aws_stream_decoder.py diff --git a/README.md b/README.md index fd920bad..522fb358 100644 --- a/README.md +++ b/README.md @@ -502,6 +502,32 @@ response = client.chat.completions.create( ) ``` +#### Stream + +```python +from ai21 import AI21BedrockClient, BedrockModelID +from ai21.models.chat import ChatMessage + +system = "You're a support engineer in a SaaS company" +messages = [ + ChatMessage(content=system, role="system"), + ChatMessage(content="Hello, I need help with a signup process.", role="user"), + ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), +] + +client = AI21BedrockClient() + +response = client.chat.completions.create( + messages=messages, + model=BedrockModelID.JAMBA_INSTRUCT_V1, + stream=True, +) + +for chunk in response: + print(chunk.choices[0].message.content, end="") +``` + #### Async ```python diff --git a/ai21/clients/bedrock/_stream_decoder.py b/ai21/clients/bedrock/_stream_decoder.py new file mode 100644 index 00000000..3d0ea5f4 --- /dev/null +++ b/ai21/clients/bedrock/_stream_decoder.py @@ -0,0 +1,68 @@ +from functools import lru_cache +from typing import Iterator, AsyncIterator + +import httpx +from botocore.model import Shape +from botocore.eventstream import EventStreamMessage + +from ai21.errors import StreamingDecodeError + + +@lru_cache(maxsize=None) +def get_response_stream_shape() -> Shape: + from botocore.model import ServiceModel + from botocore.loaders import Loader + + loader = Loader() + bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") + bedrock_service_model = ServiceModel(bedrock_service_dict) + return bedrock_service_model.shape_for("ResponseStream") + + +class AWSEventStreamDecoder: + def __init__(self) -> None: + from botocore.parsers import EventStreamJSONParser + + self.parser = EventStreamJSONParser() + + def iter(self, response: httpx.Response) -> Iterator[str]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + for chunk in response.iter_bytes(): + try: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + yield message + except Exception as e: + raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) + + async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + async for chunk in response.aiter_bytes(): + try: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + yield message + except Exception as e: + raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) + + def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: + response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + if response_dict["status_code"] != 200: + raise ValueError(f"Bad response code, expected 200: {response_dict}") + + chunk = parsed_response.get("chunk") + if not chunk: + return None + + return chunk.get("bytes").decode() # type: ignore[no-any-return] diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index af893e66..a9dab046 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -9,6 +9,7 @@ from ai21 import AI21APIError from ai21.ai21_env_config import AI21EnvConfig from ai21.clients.aws.aws_authorization import AWSAuthorization +from ai21.clients.bedrock._stream_decoder import AWSEventStreamDecoder from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException @@ -48,12 +49,12 @@ def _prepare_options(self, options: RequestOptions) -> RequestOptions: model = body.pop("model", None) stream = body.pop("stream", False) - + endpoint = "invoke" if stream: - _logger.warning("Field stream is not supported. Ignoring it.") + endpoint = "invoke-with-response-stream" # When stream is supported we would need to update this section and the URL - url = f"{options.url}/model/{model}/invoke" + url = f"{options.url}/model/{model}/{endpoint}" headers = self._prepare_headers(url=url, body=body) return options.replace( @@ -119,6 +120,9 @@ def _build_request(self, options: RequestOptions) -> httpx.Request: def _prepare_url(self, options: RequestOptions) -> str: return options.url + def _get_streaming_decoder(self) -> AWSEventStreamDecoder: + return AWSEventStreamDecoder() + class AsyncAI21BedrockClient(AsyncAI21HTTPClient, BaseBedrockClient): def __init__( @@ -167,3 +171,6 @@ def _build_request(self, options: RequestOptions) -> httpx.Request: def _prepare_url(self, options: RequestOptions) -> str: return options.url + + def _get_streaming_decoder(self) -> AWSEventStreamDecoder: + return AWSEventStreamDecoder() diff --git a/ai21/clients/studio/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py index b6386ec3..372b9f32 100644 --- a/ai21/clients/studio/resources/studio_resource.py +++ b/ai21/clients/studio/resources/studio_resource.py @@ -17,10 +17,11 @@ def _cast_response( response_cls: Optional[ResponseT], stream_cls: Optional[AsyncStreamT] = None, stream: bool = False, + streaming_decoder: Optional[Any] = None, ) -> ResponseT | AsyncStreamT | None: if stream and stream_cls is not None: cast_to = extract_type(stream_cls) - return stream_cls(cast_to=cast_to, response=response) + return stream_cls(cast_to=cast_to, response=response, streaming_decoder=streaming_decoder) if response_cls is None: return None @@ -63,7 +64,13 @@ def _post( files=files, ) - return _cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls) + return _cast_response( + stream=stream, + response=response, + response_cls=response_cls, + stream_cls=stream_cls, + streaming_decoder=self._client._get_streaming_decoder(), + ) def _get( self, path: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None @@ -108,7 +115,13 @@ async def _post( files=files, ) - return _cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls) + return _cast_response( + stream=stream, + response=response, + response_cls=response_cls, + stream_cls=stream_cls, + streaming_decoder=self._client._get_streaming_decoder(), + ) async def _get( self, path: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None diff --git a/ai21/errors.py b/ai21/errors.py index 83091ed3..f295fb13 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -95,8 +95,10 @@ def __init__(self, key: str): class StreamingDecodeError(AI21Error): - def __init__(self, chunk: str): + def __init__(self, chunk: str, error_message: Optional[str] = None): message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format" + if error_message: + message += f". Error: {error_message}" super().__init__(message) diff --git a/ai21/http_client/async_http_client.py b/ai21/http_client/async_http_client.py index 0f7281a6..673e4f1e 100644 --- a/ai21/http_client/async_http_client.py +++ b/ai21/http_client/async_http_client.py @@ -119,3 +119,6 @@ def _init_client(self, client: Optional[httpx.AsyncClient]) -> httpx.AsyncClient return httpx.AsyncClient(transport=_requests_retry_async_session(retries=self._num_retries)) return httpx.AsyncClient() + + def _get_streaming_decoder(self): + return None diff --git a/ai21/http_client/http_client.py b/ai21/http_client/http_client.py index 6ad1dd49..93f0db09 100644 --- a/ai21/http_client/http_client.py +++ b/ai21/http_client/http_client.py @@ -117,3 +117,6 @@ def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client: return httpx.Client(transport=_requests_retry_session(retries=self._num_retries)) return httpx.Client() + + def _get_streaming_decoder(self): + return None diff --git a/ai21/stream/async_stream.py b/ai21/stream/async_stream.py index 9203938f..1358c697 100644 --- a/ai21/stream/async_stream.py +++ b/ai21/stream/async_stream.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, AsyncIterator +from typing import Generic, AsyncIterator, Optional, Any from typing_extensions import Self from types import TracebackType @@ -17,10 +17,11 @@ def __init__( *, cast_to: type[_T], response: httpx.Response, + streaming_decoder: Optional[Any] = None, ): self.response = response self.cast_to = cast_to - self._decoder = _SSEDecoder() + self._decoder = streaming_decoder or _SSEDecoder() self._iterator = self.__stream__() async def __anext__(self) -> _T: @@ -31,7 +32,7 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def __stream__(self) -> AsyncIterator[_T]: - iterator = self._decoder.aiter(self.response.aiter_lines()) + iterator = self._decoder.aiter(self.response) async for chunk in iterator: if chunk.endswith(_SSE_DONE_MSG): break diff --git a/ai21/stream/stream.py b/ai21/stream/stream.py index e4df59cd..bbc52ded 100644 --- a/ai21/stream/stream.py +++ b/ai21/stream/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import TracebackType -from typing import Generic, Iterator +from typing import Generic, Iterator, Optional, Any import httpx from typing_extensions import Self @@ -17,10 +17,11 @@ def __init__( *, cast_to: type[_T], response: httpx.Response, + streaming_decoder: Optional[Any] = None, ): self.response = response self.cast_to = cast_to - self._decoder = _SSEDecoder() + self._decoder = streaming_decoder or _SSEDecoder() self._iterator = self.__stream__() def __next__(self) -> _T: @@ -31,7 +32,7 @@ def __iter__(self) -> Iterator[_T]: yield item def __stream__(self) -> Iterator[_T]: - iterator = self._decoder.iter(self.response.iter_lines()) + iterator = self._decoder.iter(self.response) for chunk in iterator: if chunk.endswith(_SSE_DONE_MSG): break diff --git a/ai21/stream/stream_commons.py b/ai21/stream/stream_commons.py index 3cda37e0..13e84623 100644 --- a/ai21/stream/stream_commons.py +++ b/ai21/stream/stream_commons.py @@ -2,6 +2,9 @@ import json from typing import TypeVar, Iterator, AsyncIterator, Optional + +import httpx + from ai21.errors import StreamingDecodeError @@ -22,16 +25,16 @@ def get_stream_message(chunk: str, cast_to: type[_T]) -> Iterator[_T] | AsyncIte class _SSEDecoder: - def iter(self, iterator: Iterator[str]): - for line in iterator: + def iter(self, response: httpx.Response): + for line in response.iter_lines(): line = line.strip() decoded_line = self._decode(line) if decoded_line is not None: yield decoded_line - async def aiter(self, iterator: AsyncIterator[str]): - async for line in iterator: + async def aiter(self, response: httpx.Response): + async for line in response.aiter_lines(): line = line.strip() decoded_line = self._decode(line) diff --git a/examples/bedrock/chat/async_stream_chat_completions.py b/examples/bedrock/chat/async_stream_chat_completions.py new file mode 100644 index 00000000..fbb486d2 --- /dev/null +++ b/examples/bedrock/chat/async_stream_chat_completions.py @@ -0,0 +1,25 @@ +import asyncio +from ai21 import AsyncAI21BedrockClient, BedrockModelID +from ai21.models.chat import ChatMessage + +client = AsyncAI21BedrockClient(region="us-east-1") # region is optional, as you can use the env variable instead + +messages = [ + ChatMessage(content="You are a helpful assistant", role="system"), + ChatMessage(content="What is the meaning of life?", role="user"), +] + + +async def main(): + response = await client.chat.completions.create( + messages=messages, + model=BedrockModelID.JAMBA_INSTRUCT_V1, + max_tokens=100, + stream=True, + ) + + async for chunk in response: + print(chunk.choices[0].message.content, end="") + + +asyncio.run(main()) diff --git a/examples/bedrock/chat/stream_chat_completions.py b/examples/bedrock/chat/stream_chat_completions.py new file mode 100644 index 00000000..50243504 --- /dev/null +++ b/examples/bedrock/chat/stream_chat_completions.py @@ -0,0 +1,22 @@ +from ai21 import AI21BedrockClient, BedrockModelID +from ai21.models.chat import ChatMessage + +system = "You're a support engineer in a SaaS company" +messages = [ + ChatMessage(content=system, role="system"), + ChatMessage(content="Hello, I need help with a signup process.", role="user"), + ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), +] + +client = AI21BedrockClient() + +response = client.chat.completions.create( + messages=messages, + model=BedrockModelID.JAMBA_INSTRUCT_V1, + max_tokens=100, + stream=True, +) + +for chunk in response: + print(chunk.choices[0].message.content, end="") diff --git a/tests/integration_tests/clients/test_bedrock.py b/tests/integration_tests/clients/test_bedrock.py index 6b245a08..0ecedeba 100644 --- a/tests/integration_tests/clients/test_bedrock.py +++ b/tests/integration_tests/clients/test_bedrock.py @@ -21,13 +21,17 @@ ("completion.py",), ("async_completion.py",), ("chat/chat_completions.py",), + # ("chat/stream_chat_completions.py",), ("chat/async_chat_completions.py",), + # ("chat/stream)async_chat_completions.py",), ], ids=[ "when_completion__should_return_ok", "when_async_completion__should_return_ok", "when_chat_completions__should_return_ok", + # "when_stream_chat_completions__should_return_ok", "when_async_chat_completions__should_return_ok", + # "when_stream_async_chat_completions__should_return_ok", ], ) def test_bedrock(test_file_name: str): diff --git a/tests/unittests/test_aws_stream_decoder.py b/tests/unittests/test_aws_stream_decoder.py new file mode 100644 index 00000000..f7314d2c --- /dev/null +++ b/tests/unittests/test_aws_stream_decoder.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass +from typing import AsyncIterable, Iterable, Optional, List + +import httpx +import pytest + +from ai21.clients.bedrock._stream_decoder import AWSEventStreamDecoder +from ai21.errors import StreamingDecodeError +from ai21.stream.async_stream import AsyncStream +from ai21.stream.stream import Stream + + +@dataclass +class TestChoiceDelta: + content: Optional[str] = None + role: Optional[str] = None + + +@dataclass +class TestChoicesChunk: + index: int + message: TestChoiceDelta + finish_reason: Optional[str] = None + + +@dataclass +class TestChatCompletionChunk: + choices: List[TestChoicesChunk] + + +def byte_stream() -> Iterable[bytes]: + for i in range(10): + yield ( + b"\x00\x00\x01\x0b\x00\x00\x00K8\xa0\xa5\xc5\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" + b"\x10application/json\r:message-type\x07\x00\x05event{" + b'"bytes":"eyJjaG9pY2VzIjpbeyJpbmRleCI6MCwibWVzc2FnZSI6eyJyb2xlIjoiYXNzaXN0YW50IiwiY29udGVudCI6Ikkif' + b'Swic3RvcF9yZWFzb24iOm51bGx9XX0=","p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJK"}\x11\x061?' + ) + + +async def async_byte_stream() -> AsyncIterable[bytes]: + for i in range(10): + yield ( + b"\x00\x00\x01\x0b\x00\x00\x00K8\xa0\xa5\xc5\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" + b"\x10application/json\r:message-type\x07\x00\x05event{" + b'"bytes":"eyJjaG9pY2VzIjpbeyJpbmRleCI6MCwibWVzc2FnZSI6eyJyb2xlIjoiYXNzaXN0YW50IiwiY29udGVudCI6Ikkif' + b'Swic3RvcF9yZWFzb24iOm51bGx9XX0=","p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJK"}\x11\x061?' + ) + + +def byte_bad_stream_json_format() -> AsyncIterable[bytes]: + msg = "data: not a json format\r\n" + yield msg.encode("utf-8") + + +async def async_byte_bad_stream_json_format() -> AsyncIterable[bytes]: + msg = "data: not a json format\r\n" + yield msg.encode("utf-8") + + +def test_stream_object_when_json_string_ok__should_be_ok(): + stream = byte_stream() + response = httpx.Response(status_code=200, content=stream) + stream_obj = Stream[TestChatCompletionChunk]( + response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + ) + + chunk_counter = 0 + for i, chunk in enumerate(stream_obj): + assert isinstance(chunk, TestChatCompletionChunk) + chunk_counter += 1 + + assert chunk_counter == 10 + + +@pytest.mark.asyncio +async def test_async_stream_object_when_json_string_ok__should_be_ok(): + stream = async_byte_stream() + response = httpx.Response(status_code=200, content=stream) + stream_obj = AsyncStream[TestChatCompletionChunk]( + response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + ) + + chunk_counter = 0 + async for chunk in stream_obj: + assert isinstance(chunk, TestChatCompletionChunk) + chunk_counter += 1 + + assert chunk_counter == 10 + + +def test_stream_object_when_bad_json__should_raise_error(): + stream = byte_bad_stream_json_format() + response = httpx.Response(status_code=200, content=stream) + stream_obj = Stream[TestChatCompletionChunk]( + response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + ) + + with pytest.raises(StreamingDecodeError): + for _ in stream_obj: + pass + + +@pytest.mark.asyncio +async def test_async_stream_object_when_bad_json__should_raise_error(): + stream = async_byte_bad_stream_json_format() + response = httpx.Response(status_code=200, content=stream) + stream_obj = AsyncStream[TestChatCompletionChunk]( + response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + ) + + with pytest.raises(StreamingDecodeError): + async for _ in stream_obj: + pass From 6b06c8b1603fd02702166e7bcaf5f75048ce121a Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Sun, 28 Jul 2024 15:09:59 +0300 Subject: [PATCH 02/11] chore: fix error --- ai21/clients/bedrock/_stream_decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ai21/clients/bedrock/_stream_decoder.py b/ai21/clients/bedrock/_stream_decoder.py index 3d0ea5f4..5a81a460 100644 --- a/ai21/clients/bedrock/_stream_decoder.py +++ b/ai21/clients/bedrock/_stream_decoder.py @@ -1,9 +1,11 @@ +from __future__ import annotations + from functools import lru_cache from typing import Iterator, AsyncIterator import httpx -from botocore.model import Shape from botocore.eventstream import EventStreamMessage +from botocore.model import Shape from ai21.errors import StreamingDecodeError From f63df64a53fdec202103e032ce354cbee66246e7 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 29 Jul 2024 14:06:04 +0300 Subject: [PATCH 03/11] chore: fix stream examples --- examples/bedrock/chat/async_stream_chat_completions.py | 2 +- examples/bedrock/chat/stream_chat_completions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bedrock/chat/async_stream_chat_completions.py b/examples/bedrock/chat/async_stream_chat_completions.py index fbb486d2..e8b1136b 100644 --- a/examples/bedrock/chat/async_stream_chat_completions.py +++ b/examples/bedrock/chat/async_stream_chat_completions.py @@ -19,7 +19,7 @@ async def main(): ) async for chunk in response: - print(chunk.choices[0].message.content, end="") + print(chunk.choices[0].delta.content, end="") asyncio.run(main()) diff --git a/examples/bedrock/chat/stream_chat_completions.py b/examples/bedrock/chat/stream_chat_completions.py index 50243504..d27263ad 100644 --- a/examples/bedrock/chat/stream_chat_completions.py +++ b/examples/bedrock/chat/stream_chat_completions.py @@ -19,4 +19,4 @@ ) for chunk in response: - print(chunk.choices[0].message.content, end="") + print(chunk.choices[0].delta.content, end="") From f778eea4251bed6bc0dce04a3cc9ff8785eebad7 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 29 Jul 2024 16:57:41 +0300 Subject: [PATCH 04/11] fix: add fix for last chunk on bedrock, add tests --- ai21/clients/bedrock/_stream_decoder.py | 27 +++++++++++ ai21/models/ai21_base_model.py | 5 ++ .../clients/bedrock/test_chat_completions.py | 48 +++++++++++++++++++ .../integration_tests/clients/test_bedrock.py | 8 ++-- 4 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 tests/integration_tests/clients/bedrock/test_chat_completions.py diff --git a/ai21/clients/bedrock/_stream_decoder.py b/ai21/clients/bedrock/_stream_decoder.py index 5a81a460..f662f84c 100644 --- a/ai21/clients/bedrock/_stream_decoder.py +++ b/ai21/clients/bedrock/_stream_decoder.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from functools import lru_cache from typing import Iterator, AsyncIterator @@ -32,12 +33,22 @@ def iter(self, response: httpx.Response) -> Iterator[str]: from botocore.eventstream import EventStreamBuffer event_stream_buffer = EventStreamBuffer() + last_message = None for chunk in response.iter_bytes(): try: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: + # START TODO: remove the following conditions once response from Bedrock changes + if last_message is not None: + message = self._create_bedrock_last_chunk( + chunk=last_message, bedrock_metrics_message=message + ) + if '"finish_reason":null' not in message and last_message is None: + last_message = message + continue + # END TODO yield message except Exception as e: raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) @@ -47,12 +58,22 @@ async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: from botocore.eventstream import EventStreamBuffer event_stream_buffer = EventStreamBuffer() + last_message = None async for chunk in response.aiter_bytes(): try: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: + # START TODO: remove the following conditions once response from Bedrock changes + if last_message is not None: + message = self._create_bedrock_last_chunk( + chunk=last_message, bedrock_metrics_message=message + ) + if '"finish_reason":null' not in message and last_message is None: + last_message = message + continue + # END TODO yield message except Exception as e: raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) @@ -68,3 +89,9 @@ def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: return None return chunk.get("bytes").decode() # type: ignore[no-any-return] + + def _create_bedrock_last_chunk(self, chunk: str, bedrock_metrics_message: str) -> str: + chunk_dict = json.loads(chunk) + bedrock_metrics_dict = json.loads(bedrock_metrics_message) + chunk_dict = {**chunk_dict, **bedrock_metrics_dict} + return json.dumps(chunk_dict) diff --git a/ai21/models/ai21_base_model.py b/ai21/models/ai21_base_model.py index 3b67fd12..9c237ca9 100644 --- a/ai21/models/ai21_base_model.py +++ b/ai21/models/ai21_base_model.py @@ -5,17 +5,22 @@ from ai21.models._pydantic_compatibility import _to_dict, _to_json, _from_dict, _from_json, IS_PYDANTIC_V2 +if not IS_PYDANTIC_V2: + from pydantic import Extra + class AI21BaseModel(BaseModel): if IS_PYDANTIC_V2: model_config = ConfigDict( populate_by_name=True, protected_namespaces=(), + extra="allow", ) else: class Config: allow_population_by_field_name = True + extra = Extra.allow def to_dict(self, **kwargs) -> Dict[str, Any]: warnings.warn( diff --git a/tests/integration_tests/clients/bedrock/test_chat_completions.py b/tests/integration_tests/clients/bedrock/test_chat_completions.py new file mode 100644 index 00000000..d1132143 --- /dev/null +++ b/tests/integration_tests/clients/bedrock/test_chat_completions.py @@ -0,0 +1,48 @@ +import pytest + +from ai21 import AI21BedrockClient, AsyncAI21BedrockClient, BedrockModelID +from ai21.models._pydantic_compatibility import _to_dict +from ai21.models.chat import ChatMessage + +_SYSTEM_MSG = "You're a support engineer in a SaaS company" +_MESSAGES = [ + ChatMessage(content=_SYSTEM_MSG, role="system"), + ChatMessage(content="Hello, I need help with a signup process.", role="user"), +] + + +def test_chat_completions__when_stream__last_chunk_should_hold_bedrock_metrics(): + client = AI21BedrockClient() + response = client.chat.completions.create( + messages=_MESSAGES, + model=BedrockModelID.JAMBA_INSTRUCT_V1, + stream=True, + ) + + last_chunk = None + for chunk in response: + assert chunk.id is not None + assert chunk.choices is not None + last_chunk = chunk + + chunk_dict = _to_dict(last_chunk) + assert "amazon-bedrock-invocationMetrics" in chunk_dict + + +@pytest.mark.asyncio +async def test__async_chat_completions__when_stream__last_chunk_should_hold_bedrock_metrics(): + client = AsyncAI21BedrockClient() + response = await client.chat.completions.create( + messages=_MESSAGES, + model=BedrockModelID.JAMBA_INSTRUCT_V1, + stream=True, + ) + + last_chunk = None + async for chunk in response: + assert chunk.id is not None + assert chunk.choices is not None + last_chunk = chunk + + chunk_dict = _to_dict(last_chunk) + assert "amazon-bedrock-invocationMetrics" in chunk_dict diff --git a/tests/integration_tests/clients/test_bedrock.py b/tests/integration_tests/clients/test_bedrock.py index 0ecedeba..7c2dfff5 100644 --- a/tests/integration_tests/clients/test_bedrock.py +++ b/tests/integration_tests/clients/test_bedrock.py @@ -21,17 +21,17 @@ ("completion.py",), ("async_completion.py",), ("chat/chat_completions.py",), - # ("chat/stream_chat_completions.py",), + ("chat/stream_chat_completions.py",), ("chat/async_chat_completions.py",), - # ("chat/stream)async_chat_completions.py",), + ("chat/stream_async_chat_completions.py",), ], ids=[ "when_completion__should_return_ok", "when_async_completion__should_return_ok", "when_chat_completions__should_return_ok", - # "when_stream_chat_completions__should_return_ok", + "when_stream_chat_completions__should_return_ok", "when_async_chat_completions__should_return_ok", - # "when_stream_async_chat_completions__should_return_ok", + "when_stream_async_chat_completions__should_return_ok", ], ) def test_bedrock(test_file_name: str): From b01df9eadde72045c3dbbfa6a57b8d43819ad256 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 29 Jul 2024 17:06:59 +0300 Subject: [PATCH 05/11] test: fix unittests --- tests/unittests/test_aws_stream_decoder.py | 72 ++++++++++++---------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/tests/unittests/test_aws_stream_decoder.py b/tests/unittests/test_aws_stream_decoder.py index f7314d2c..0f0ede9d 100644 --- a/tests/unittests/test_aws_stream_decoder.py +++ b/tests/unittests/test_aws_stream_decoder.py @@ -1,50 +1,54 @@ -from dataclasses import dataclass -from typing import AsyncIterable, Iterable, Optional, List +from typing import AsyncIterable, Iterable import httpx import pytest from ai21.clients.bedrock._stream_decoder import AWSEventStreamDecoder from ai21.errors import StreamingDecodeError +from ai21.models.chat import ChatCompletionChunk from ai21.stream.async_stream import AsyncStream from ai21.stream.stream import Stream -@dataclass -class TestChoiceDelta: - content: Optional[str] = None - role: Optional[str] = None - - -@dataclass -class TestChoicesChunk: - index: int - message: TestChoiceDelta - finish_reason: Optional[str] = None - - -@dataclass -class TestChatCompletionChunk: - choices: List[TestChoicesChunk] +# @dataclass +# class TestChoiceDelta: +# content: Optional[str] = None +# role: Optional[str] = None +# +# +# @dataclass +# class TestChoicesChunk: +# index: int +# message: TestChoiceDelta +# finish_reason: Optional[str] = None +# +# +# @dataclass +# class TestChatCompletionChunk: +# choices: List[TestChoicesChunk] def byte_stream() -> Iterable[bytes]: for i in range(10): yield ( - b"\x00\x00\x01\x0b\x00\x00\x00K8\xa0\xa5\xc5\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" + b"\x00\x00\x01\x80\x00\x00\x00K\xfe\x96$F\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" b"\x10application/json\r:message-type\x07\x00\x05event{" - b'"bytes":"eyJjaG9pY2VzIjpbeyJpbmRleCI6MCwibWVzc2FnZSI6eyJyb2xlIjoiYXNzaXN0YW50IiwiY29udGVudCI6Ikkif' - b'Swic3RvcF9yZWFzb24iOm51bGx9XX0=","p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJK"}\x11\x061?' + b'"bytes":"eyJpZCI6ImNtcGwtOTgxZjdmMTc2YWQ0NDE0NTliOTRlNDVlZTI5MmEzMjEiLCJjaG9pY2VzIjpbeyJpbmRleC' + b"I6MCwiZGVsdGEiOnsicm9sZSI6ImFzc2lzdGFudCJ9LCJmaW5pc2hfcmVhc29uIjpudWxsfV0sInVzYWdlIjp7InByb21wd" + b'F90b2tlbnMiOjQ0LCJ0b3RhbF90b2tlbnMiOjQ0LCJjb21wbGV0aW9uX3Rva2VucyI6MH19","p":"abcdefghijklmnopq' + b'rstuv"}5\xca\xa7\x98' ) async def async_byte_stream() -> AsyncIterable[bytes]: for i in range(10): yield ( - b"\x00\x00\x01\x0b\x00\x00\x00K8\xa0\xa5\xc5\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" + b"\x00\x00\x01\x80\x00\x00\x00K\xfe\x96$F\x0b:event-type\x07\x00\x05chunk\r:content-type\x07\x00" b"\x10application/json\r:message-type\x07\x00\x05event{" - b'"bytes":"eyJjaG9pY2VzIjpbeyJpbmRleCI6MCwibWVzc2FnZSI6eyJyb2xlIjoiYXNzaXN0YW50IiwiY29udGVudCI6Ikkif' - b'Swic3RvcF9yZWFzb24iOm51bGx9XX0=","p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJK"}\x11\x061?' + b'"bytes":"eyJpZCI6ImNtcGwtOTgxZjdmMTc2YWQ0NDE0NTliOTRlNDVlZTI5MmEzMjEiLCJjaG9pY2VzIjpbeyJpbmRleC' + b"I6MCwiZGVsdGEiOnsicm9sZSI6ImFzc2lzdGFudCJ9LCJmaW5pc2hfcmVhc29uIjpudWxsfV0sInVzYWdlIjp7InByb21wd" + b'F90b2tlbnMiOjQ0LCJ0b3RhbF90b2tlbnMiOjQ0LCJjb21wbGV0aW9uX3Rva2VucyI6MH19","p":"abcdefghijklmnopq' + b'rstuv"}5\xca\xa7\x98' ) @@ -61,13 +65,13 @@ async def async_byte_bad_stream_json_format() -> AsyncIterable[bytes]: def test_stream_object_when_json_string_ok__should_be_ok(): stream = byte_stream() response = httpx.Response(status_code=200, content=stream) - stream_obj = Stream[TestChatCompletionChunk]( - response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + stream_obj = Stream[ChatCompletionChunk]( + response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() ) chunk_counter = 0 for i, chunk in enumerate(stream_obj): - assert isinstance(chunk, TestChatCompletionChunk) + assert isinstance(chunk, ChatCompletionChunk) chunk_counter += 1 assert chunk_counter == 10 @@ -77,13 +81,13 @@ def test_stream_object_when_json_string_ok__should_be_ok(): async def test_async_stream_object_when_json_string_ok__should_be_ok(): stream = async_byte_stream() response = httpx.Response(status_code=200, content=stream) - stream_obj = AsyncStream[TestChatCompletionChunk]( - response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + stream_obj = AsyncStream[ChatCompletionChunk]( + response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() ) chunk_counter = 0 async for chunk in stream_obj: - assert isinstance(chunk, TestChatCompletionChunk) + assert isinstance(chunk, ChatCompletionChunk) chunk_counter += 1 assert chunk_counter == 10 @@ -92,8 +96,8 @@ async def test_async_stream_object_when_json_string_ok__should_be_ok(): def test_stream_object_when_bad_json__should_raise_error(): stream = byte_bad_stream_json_format() response = httpx.Response(status_code=200, content=stream) - stream_obj = Stream[TestChatCompletionChunk]( - response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + stream_obj = Stream[ChatCompletionChunk]( + response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() ) with pytest.raises(StreamingDecodeError): @@ -105,8 +109,8 @@ def test_stream_object_when_bad_json__should_raise_error(): async def test_async_stream_object_when_bad_json__should_raise_error(): stream = async_byte_bad_stream_json_format() response = httpx.Response(status_code=200, content=stream) - stream_obj = AsyncStream[TestChatCompletionChunk]( - response=response, cast_to=TestChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + stream_obj = AsyncStream[ChatCompletionChunk]( + response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() ) with pytest.raises(StreamingDecodeError): From 661c5f574a3128749fb8316e950e0b2602d8c8d0 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 29 Jul 2024 17:07:31 +0300 Subject: [PATCH 06/11] test: remove comments --- tests/unittests/test_aws_stream_decoder.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/unittests/test_aws_stream_decoder.py b/tests/unittests/test_aws_stream_decoder.py index 0f0ede9d..b255f8d6 100644 --- a/tests/unittests/test_aws_stream_decoder.py +++ b/tests/unittests/test_aws_stream_decoder.py @@ -10,24 +10,6 @@ from ai21.stream.stream import Stream -# @dataclass -# class TestChoiceDelta: -# content: Optional[str] = None -# role: Optional[str] = None -# -# -# @dataclass -# class TestChoicesChunk: -# index: int -# message: TestChoiceDelta -# finish_reason: Optional[str] = None -# -# -# @dataclass -# class TestChatCompletionChunk: -# choices: List[TestChoicesChunk] - - def byte_stream() -> Iterable[bytes]: for i in range(10): yield ( From ae2c38e36f68f97e2fff3983ab9df34ae6e63374 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 29 Jul 2024 17:19:34 +0300 Subject: [PATCH 07/11] test: fix integration test --- tests/integration_tests/clients/test_bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/test_bedrock.py b/tests/integration_tests/clients/test_bedrock.py index 7c2dfff5..79298ed4 100644 --- a/tests/integration_tests/clients/test_bedrock.py +++ b/tests/integration_tests/clients/test_bedrock.py @@ -23,7 +23,7 @@ ("chat/chat_completions.py",), ("chat/stream_chat_completions.py",), ("chat/async_chat_completions.py",), - ("chat/stream_async_chat_completions.py",), + ("chat/async_stream_chat_completions.py",), ], ids=[ "when_completion__should_return_ok", From 236461016ac9994326336f7c3c7833dcc47d4bd7 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Sun, 4 Aug 2024 16:52:03 +0300 Subject: [PATCH 08/11] chore: cr comments --- ai21/clients/bedrock/_stream_decoder.py | 88 ++++++++----------- ai21/clients/bedrock/ai21_bedrock_client.py | 5 +- .../studio/resources/studio_resource.py | 3 +- ai21/errors.py | 4 +- ai21/http_client/async_http_client.py | 6 +- ai21/http_client/http_client.py | 6 +- ai21/models/ai21_base_model.py | 6 +- ai21/stream/stream_commons.py | 14 ++- .../clients/bedrock/test_chat_completions.py | 14 +-- 9 files changed, 69 insertions(+), 77 deletions(-) diff --git a/ai21/clients/bedrock/_stream_decoder.py b/ai21/clients/bedrock/_stream_decoder.py index f662f84c..0720bed2 100644 --- a/ai21/clients/bedrock/_stream_decoder.py +++ b/ai21/clients/bedrock/_stream_decoder.py @@ -5,10 +5,12 @@ from typing import Iterator, AsyncIterator import httpx -from botocore.eventstream import EventStreamMessage +from botocore.eventstream import EventStreamMessage, EventStreamBuffer from botocore.model import Shape +from botocore.parsers import EventStreamJSONParser from ai21.errors import StreamingDecodeError +from ai21.stream.stream_commons import SSEDecoderBase @lru_cache(maxsize=None) @@ -22,65 +24,39 @@ def get_response_stream_shape() -> Shape: return bedrock_service_model.shape_for("ResponseStream") -class AWSEventStreamDecoder: +class AWSEventStreamDecoder(SSEDecoderBase): def __init__(self) -> None: - from botocore.parsers import EventStreamJSONParser - - self.parser = EventStreamJSONParser() + self._parser = EventStreamJSONParser() def iter(self, response: httpx.Response) -> Iterator[str]: - """Given an iterator that yields lines, iterate over it & yield every event encountered""" - from botocore.eventstream import EventStreamBuffer - event_stream_buffer = EventStreamBuffer() - last_message = None + previous_item = None for chunk in response.iter_bytes(): - try: - event_stream_buffer.add_data(chunk) - for event in event_stream_buffer: - message = self._parse_message_from_event(event) - if message: - # START TODO: remove the following conditions once response from Bedrock changes - if last_message is not None: - message = self._create_bedrock_last_chunk( - chunk=last_message, bedrock_metrics_message=message - ) - if '"finish_reason":null' not in message and last_message is None: - last_message = message - continue - # END TODO - yield message - except Exception as e: - raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) + item = next(self._process_chunks(event_stream_buffer, chunk)) + # For Bedrock metering chunk: + if previous_item is not None: + item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item) + if '"finish_reason":null' not in item and previous_item is None: + previous_item = item + continue + yield item async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: - """Given an async iterator that yields lines, iterate over it & yield every event encountered""" - from botocore.eventstream import EventStreamBuffer - event_stream_buffer = EventStreamBuffer() - last_message = None + previous_item = None async for chunk in response.aiter_bytes(): - try: - event_stream_buffer.add_data(chunk) - for event in event_stream_buffer: - message = self._parse_message_from_event(event) - if message: - # START TODO: remove the following conditions once response from Bedrock changes - if last_message is not None: - message = self._create_bedrock_last_chunk( - chunk=last_message, bedrock_metrics_message=message - ) - if '"finish_reason":null' not in message and last_message is None: - last_message = message - continue - # END TODO - yield message - except Exception as e: - raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) + item = next(self._process_chunks(event_stream_buffer, chunk)) + # For Bedrock metering chunk: + if previous_item is not None: + item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item) + if '"finish_reason":null' not in item and previous_item is None: + previous_item = item + continue + yield item def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: response_dict = event.to_response_dict() - parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + parsed_response = self._parser.parse(response_dict, get_response_stream_shape()) if response_dict["status_code"] != 200: raise ValueError(f"Bad response code, expected 200: {response_dict}") @@ -90,8 +66,18 @@ def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: return chunk.get("bytes").decode() # type: ignore[no-any-return] - def _create_bedrock_last_chunk(self, chunk: str, bedrock_metrics_message: str) -> str: - chunk_dict = json.loads(chunk) - bedrock_metrics_dict = json.loads(bedrock_metrics_message) + def _build_last_chunk(self, last_model_chunk: str, bedrock_metrics_chunk: str) -> str: + chunk_dict = json.loads(last_model_chunk) + bedrock_metrics_dict = json.loads(bedrock_metrics_chunk) chunk_dict = {**chunk_dict, **bedrock_metrics_dict} return json.dumps(chunk_dict) + + def _process_chunks(self, event_stream_buffer, chunk): + try: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + yield message + except Exception as e: + raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index a9dab046..f12a2c99 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -43,6 +43,7 @@ def _handle_bedrock_error(aws_error: AI21APIError) -> None: class BaseBedrockClient: def __init__(self, region: str, session: Optional[boto3.Session]): self._aws_auth = AWSAuthorization(aws_session=session or boto3.Session(region_name=region)) + self._streaming_decoder = AWSEventStreamDecoder() def _prepare_options(self, options: RequestOptions) -> RequestOptions: body = options.body @@ -121,7 +122,7 @@ def _prepare_url(self, options: RequestOptions) -> str: return options.url def _get_streaming_decoder(self) -> AWSEventStreamDecoder: - return AWSEventStreamDecoder() + return self._streaming_decoder class AsyncAI21BedrockClient(AsyncAI21HTTPClient, BaseBedrockClient): @@ -173,4 +174,4 @@ def _prepare_url(self, options: RequestOptions) -> str: return options.url def _get_streaming_decoder(self) -> AWSEventStreamDecoder: - return AWSEventStreamDecoder() + return self._streaming_decoder diff --git a/ai21/clients/studio/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py index 36fc05ad..38392433 100644 --- a/ai21/clients/studio/resources/studio_resource.py +++ b/ai21/clients/studio/resources/studio_resource.py @@ -9,6 +9,7 @@ from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient from ai21.models._pydantic_compatibility import _from_dict +from ai21.stream.stream_commons import SSEDecoderBase from ai21.types import ResponseT, StreamT, AsyncStreamT from ai21.utils.typing import extract_type @@ -18,7 +19,7 @@ def _cast_response( response_cls: Optional[ResponseT], stream_cls: Optional[AsyncStreamT] = None, stream: bool = False, - streaming_decoder: Optional[Any] = None, + streaming_decoder: Optional[SSEDecoderBase] = None, ) -> ResponseT | AsyncStreamT | None: if stream and stream_cls is not None: cast_to = extract_type(stream_cls) diff --git a/ai21/errors.py b/ai21/errors.py index f295fb13..da0ef86a 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -96,9 +96,9 @@ def __init__(self, key: str): class StreamingDecodeError(AI21Error): def __init__(self, chunk: str, error_message: Optional[str] = None): - message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format" + message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format. " if error_message: - message += f". Error: {error_message}" + message += f"Error: {error_message}" super().__init__(message) diff --git a/ai21/http_client/async_http_client.py b/ai21/http_client/async_http_client.py index 673e4f1e..13fd536d 100644 --- a/ai21/http_client/async_http_client.py +++ b/ai21/http_client/async_http_client.py @@ -14,6 +14,7 @@ RETRY_BACK_OFF_FACTOR, TIME_BETWEEN_RETRIES, ) +from ai21.stream.stream_commons import _SSEDecoder _logger = logging.getLogger(__name__) @@ -55,6 +56,7 @@ def __init__( retry=retry_if_result(self._should_retry), stop=stop_after_attempt(self._num_retries), )(self._request) + self._streaming_decoder = _SSEDecoder() async def execute_http_request( self, @@ -120,5 +122,5 @@ def _init_client(self, client: Optional[httpx.AsyncClient]) -> httpx.AsyncClient return httpx.AsyncClient() - def _get_streaming_decoder(self): - return None + def _get_streaming_decoder(self) -> _SSEDecoder: + return self._streaming_decoder diff --git a/ai21/http_client/http_client.py b/ai21/http_client/http_client.py index 93f0db09..7f5c4600 100644 --- a/ai21/http_client/http_client.py +++ b/ai21/http_client/http_client.py @@ -13,6 +13,7 @@ ) from ai21.models.request_options import RequestOptions from ai21.stream.stream import Stream +from ai21.stream.stream_commons import _SSEDecoder _logger = logging.getLogger(__name__) @@ -54,6 +55,7 @@ def __init__( retry=retry_if_result(self._should_retry), stop=stop_after_attempt(self._num_retries), )(self._request) + self._streaming_decoder = _SSEDecoder() def execute_http_request( self, @@ -118,5 +120,5 @@ def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client: return httpx.Client() - def _get_streaming_decoder(self): - return None + def _get_streaming_decoder(self) -> _SSEDecoder: + return self._streaming_decoder diff --git a/ai21/models/ai21_base_model.py b/ai21/models/ai21_base_model.py index 9c237ca9..a1db0422 100644 --- a/ai21/models/ai21_base_model.py +++ b/ai21/models/ai21_base_model.py @@ -1,13 +1,11 @@ import warnings from typing import Any, Dict + from pydantic import BaseModel, ConfigDict from typing_extensions import Self from ai21.models._pydantic_compatibility import _to_dict, _to_json, _from_dict, _from_json, IS_PYDANTIC_V2 -if not IS_PYDANTIC_V2: - from pydantic import Extra - class AI21BaseModel(BaseModel): if IS_PYDANTIC_V2: @@ -19,6 +17,8 @@ class AI21BaseModel(BaseModel): else: class Config: + from pydantic import Extra + allow_population_by_field_name = True extra = Extra.allow diff --git a/ai21/stream/stream_commons.py b/ai21/stream/stream_commons.py index 13e84623..d8feda7c 100644 --- a/ai21/stream/stream_commons.py +++ b/ai21/stream/stream_commons.py @@ -1,13 +1,13 @@ from __future__ import annotations import json +from abc import ABC, abstractmethod from typing import TypeVar, Iterator, AsyncIterator, Optional import httpx from ai21.errors import StreamingDecodeError - _T = TypeVar("_T") _SSE_DATA_PREFIX = "data: " _SSE_DONE_MSG = "[DONE]" @@ -24,7 +24,17 @@ def get_stream_message(chunk: str, cast_to: type[_T]) -> Iterator[_T] | AsyncIte raise StreamingDecodeError(chunk) -class _SSEDecoder: +class SSEDecoderBase(ABC): + @abstractmethod + def iter(self, response: httpx.Response) -> Iterator[str]: + pass + + @abstractmethod + async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: + pass + + +class _SSEDecoder(SSEDecoderBase): def iter(self, response: httpx.Response): for line in response.iter_lines(): line = line.strip() diff --git a/tests/integration_tests/clients/bedrock/test_chat_completions.py b/tests/integration_tests/clients/bedrock/test_chat_completions.py index d1132143..c067f34a 100644 --- a/tests/integration_tests/clients/bedrock/test_chat_completions.py +++ b/tests/integration_tests/clients/bedrock/test_chat_completions.py @@ -19,12 +19,7 @@ def test_chat_completions__when_stream__last_chunk_should_hold_bedrock_metrics() stream=True, ) - last_chunk = None - for chunk in response: - assert chunk.id is not None - assert chunk.choices is not None - last_chunk = chunk - + last_chunk = list(response)[-1] chunk_dict = _to_dict(last_chunk) assert "amazon-bedrock-invocationMetrics" in chunk_dict @@ -38,11 +33,6 @@ async def test__async_chat_completions__when_stream__last_chunk_should_hold_bedr stream=True, ) - last_chunk = None - async for chunk in response: - assert chunk.id is not None - assert chunk.choices is not None - last_chunk = chunk - + last_chunk = [chunk async for chunk in response][-1] chunk_dict = _to_dict(last_chunk) assert "amazon-bedrock-invocationMetrics" in chunk_dict From 8616f4e7caa4fd8a3105b7c289bb34a97b245c05 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 5 Aug 2024 18:17:22 +0300 Subject: [PATCH 09/11] refactor: cr comments --- ai21/clients/bedrock/_stream_decoder.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ai21/clients/bedrock/_stream_decoder.py b/ai21/clients/bedrock/_stream_decoder.py index 0720bed2..e58d6bc0 100644 --- a/ai21/clients/bedrock/_stream_decoder.py +++ b/ai21/clients/bedrock/_stream_decoder.py @@ -13,6 +13,9 @@ from ai21.stream.stream_commons import SSEDecoderBase +_FINISH_REASON_NULL_STR = '"finish_reason":null' + + @lru_cache(maxsize=None) def get_response_stream_shape() -> Shape: from botocore.model import ServiceModel @@ -32,11 +35,14 @@ def iter(self, response: httpx.Response) -> Iterator[str]: event_stream_buffer = EventStreamBuffer() previous_item = None for chunk in response.iter_bytes(): - item = next(self._process_chunks(event_stream_buffer, chunk)) + try: + item = next(self._process_chunks(event_stream_buffer, chunk)) + except StopIteration as e: + raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) # For Bedrock metering chunk: if previous_item is not None: item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item) - if '"finish_reason":null' not in item and previous_item is None: + if _FINISH_REASON_NULL_STR not in item and previous_item is None: previous_item = item continue yield item @@ -45,11 +51,14 @@ async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: event_stream_buffer = EventStreamBuffer() previous_item = None async for chunk in response.aiter_bytes(): - item = next(self._process_chunks(event_stream_buffer, chunk)) + try: + item = next(self._process_chunks(event_stream_buffer, chunk)) + except StopIteration as e: + raise StreamingDecodeError(chunk=str(chunk), error_message=str(e)) # For Bedrock metering chunk: if previous_item is not None: item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item) - if '"finish_reason":null' not in item and previous_item is None: + if _FINISH_REASON_NULL_STR not in item and previous_item is None: previous_item = item continue yield item From 4f4886952a04a4b84e3abc109fb8f97f5b2d17e3 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 5 Aug 2024 18:27:44 +0300 Subject: [PATCH 10/11] refactor: async stream example --- examples/bedrock/chat/async_stream_chat_completions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/bedrock/chat/async_stream_chat_completions.py b/examples/bedrock/chat/async_stream_chat_completions.py index e8b1136b..637f45ca 100644 --- a/examples/bedrock/chat/async_stream_chat_completions.py +++ b/examples/bedrock/chat/async_stream_chat_completions.py @@ -4,9 +4,12 @@ client = AsyncAI21BedrockClient(region="us-east-1") # region is optional, as you can use the env variable instead +system = "You're a support engineer in a SaaS company" messages = [ - ChatMessage(content="You are a helpful assistant", role="system"), - ChatMessage(content="What is the meaning of life?", role="user"), + ChatMessage(content=system, role="system"), + ChatMessage(content="Hello, I need help with a signup process.", role="user"), + ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), ] From fad2c494fd332b380883c65c3375d64fb8443104 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 5 Aug 2024 19:20:52 +0300 Subject: [PATCH 11/11] refactor: cr comments --- ai21/clients/bedrock/_stream_decoder.py | 6 +++--- ai21/clients/bedrock/ai21_bedrock_client.py | 8 ++++---- ai21/clients/studio/resources/studio_resource.py | 4 ++-- ai21/errors.py | 4 ++-- ai21/stream/stream_commons.py | 4 ++-- tests/unittests/test_aws_stream_decoder.py | 10 +++++----- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/ai21/clients/bedrock/_stream_decoder.py b/ai21/clients/bedrock/_stream_decoder.py index e58d6bc0..74bb19a9 100644 --- a/ai21/clients/bedrock/_stream_decoder.py +++ b/ai21/clients/bedrock/_stream_decoder.py @@ -10,7 +10,7 @@ from botocore.parsers import EventStreamJSONParser from ai21.errors import StreamingDecodeError -from ai21.stream.stream_commons import SSEDecoderBase +from ai21.stream.stream_commons import _SSEDecoderBase _FINISH_REASON_NULL_STR = '"finish_reason":null' @@ -27,7 +27,7 @@ def get_response_stream_shape() -> Shape: return bedrock_service_model.shape_for("ResponseStream") -class AWSEventStreamDecoder(SSEDecoderBase): +class _AWSEventStreamDecoder(_SSEDecoderBase): def __init__(self) -> None: self._parser = EventStreamJSONParser() @@ -81,7 +81,7 @@ def _build_last_chunk(self, last_model_chunk: str, bedrock_metrics_chunk: str) - chunk_dict = {**chunk_dict, **bedrock_metrics_dict} return json.dumps(chunk_dict) - def _process_chunks(self, event_stream_buffer, chunk): + def _process_chunks(self, event_stream_buffer, chunk) -> Iterator[str]: try: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index f12a2c99..0eded54f 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -9,7 +9,7 @@ from ai21 import AI21APIError from ai21.ai21_env_config import AI21EnvConfig from ai21.clients.aws.aws_authorization import AWSAuthorization -from ai21.clients.bedrock._stream_decoder import AWSEventStreamDecoder +from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException @@ -43,7 +43,7 @@ def _handle_bedrock_error(aws_error: AI21APIError) -> None: class BaseBedrockClient: def __init__(self, region: str, session: Optional[boto3.Session]): self._aws_auth = AWSAuthorization(aws_session=session or boto3.Session(region_name=region)) - self._streaming_decoder = AWSEventStreamDecoder() + self._streaming_decoder = _AWSEventStreamDecoder() def _prepare_options(self, options: RequestOptions) -> RequestOptions: body = options.body @@ -121,7 +121,7 @@ def _build_request(self, options: RequestOptions) -> httpx.Request: def _prepare_url(self, options: RequestOptions) -> str: return options.url - def _get_streaming_decoder(self) -> AWSEventStreamDecoder: + def _get_streaming_decoder(self) -> _AWSEventStreamDecoder: return self._streaming_decoder @@ -173,5 +173,5 @@ def _build_request(self, options: RequestOptions) -> httpx.Request: def _prepare_url(self, options: RequestOptions) -> str: return options.url - def _get_streaming_decoder(self) -> AWSEventStreamDecoder: + def _get_streaming_decoder(self) -> _AWSEventStreamDecoder: return self._streaming_decoder diff --git a/ai21/clients/studio/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py index 38392433..e4b6634e 100644 --- a/ai21/clients/studio/resources/studio_resource.py +++ b/ai21/clients/studio/resources/studio_resource.py @@ -9,7 +9,7 @@ from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient from ai21.models._pydantic_compatibility import _from_dict -from ai21.stream.stream_commons import SSEDecoderBase +from ai21.stream.stream_commons import _SSEDecoderBase from ai21.types import ResponseT, StreamT, AsyncStreamT from ai21.utils.typing import extract_type @@ -19,7 +19,7 @@ def _cast_response( response_cls: Optional[ResponseT], stream_cls: Optional[AsyncStreamT] = None, stream: bool = False, - streaming_decoder: Optional[SSEDecoderBase] = None, + streaming_decoder: Optional[_SSEDecoderBase] = None, ) -> ResponseT | AsyncStreamT | None: if stream and stream_cls is not None: cast_to = extract_type(stream_cls) diff --git a/ai21/errors.py b/ai21/errors.py index da0ef86a..918091a1 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -96,9 +96,9 @@ def __init__(self, key: str): class StreamingDecodeError(AI21Error): def __init__(self, chunk: str, error_message: Optional[str] = None): - message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format. " + message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format." if error_message: - message += f"Error: {error_message}" + message = f"{message} Error: {error_message}" super().__init__(message) diff --git a/ai21/stream/stream_commons.py b/ai21/stream/stream_commons.py index d8feda7c..bf9dd8cd 100644 --- a/ai21/stream/stream_commons.py +++ b/ai21/stream/stream_commons.py @@ -24,7 +24,7 @@ def get_stream_message(chunk: str, cast_to: type[_T]) -> Iterator[_T] | AsyncIte raise StreamingDecodeError(chunk) -class SSEDecoderBase(ABC): +class _SSEDecoderBase(ABC): @abstractmethod def iter(self, response: httpx.Response) -> Iterator[str]: pass @@ -34,7 +34,7 @@ async def aiter(self, response: httpx.Response) -> AsyncIterator[str]: pass -class _SSEDecoder(SSEDecoderBase): +class _SSEDecoder(_SSEDecoderBase): def iter(self, response: httpx.Response): for line in response.iter_lines(): line = line.strip() diff --git a/tests/unittests/test_aws_stream_decoder.py b/tests/unittests/test_aws_stream_decoder.py index b255f8d6..178bce67 100644 --- a/tests/unittests/test_aws_stream_decoder.py +++ b/tests/unittests/test_aws_stream_decoder.py @@ -3,7 +3,7 @@ import httpx import pytest -from ai21.clients.bedrock._stream_decoder import AWSEventStreamDecoder +from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder from ai21.errors import StreamingDecodeError from ai21.models.chat import ChatCompletionChunk from ai21.stream.async_stream import AsyncStream @@ -48,7 +48,7 @@ def test_stream_object_when_json_string_ok__should_be_ok(): stream = byte_stream() response = httpx.Response(status_code=200, content=stream) stream_obj = Stream[ChatCompletionChunk]( - response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() ) chunk_counter = 0 @@ -64,7 +64,7 @@ async def test_async_stream_object_when_json_string_ok__should_be_ok(): stream = async_byte_stream() response = httpx.Response(status_code=200, content=stream) stream_obj = AsyncStream[ChatCompletionChunk]( - response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() ) chunk_counter = 0 @@ -79,7 +79,7 @@ def test_stream_object_when_bad_json__should_raise_error(): stream = byte_bad_stream_json_format() response = httpx.Response(status_code=200, content=stream) stream_obj = Stream[ChatCompletionChunk]( - response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() ) with pytest.raises(StreamingDecodeError): @@ -92,7 +92,7 @@ async def test_async_stream_object_when_bad_json__should_raise_error(): stream = async_byte_bad_stream_json_format() response = httpx.Response(status_code=200, content=stream) stream_obj = AsyncStream[ChatCompletionChunk]( - response=response, cast_to=ChatCompletionChunk, streaming_decoder=AWSEventStreamDecoder() + response=response, cast_to=ChatCompletionChunk, streaming_decoder=_AWSEventStreamDecoder() ) with pytest.raises(StreamingDecodeError):