Skip to content
Merged
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,32 @@ response = client.chat.completions.create(
)
```

#### Stream

```python
from ai21 import AI21BedrockClient, BedrockModelID
from ai21.models.chat import ChatMessage

system = "You're a support engineer in a SaaS company"
messages = [
ChatMessage(content=system, role="system"),
ChatMessage(content="Hello, I need help with a signup process.", role="user"),
ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"),
ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"),
]

client = AI21BedrockClient()

response = client.chat.completions.create(
messages=messages,
model=BedrockModelID.JAMBA_INSTRUCT_V1,
stream=True,
)

for chunk in response:
print(chunk.choices[0].message.content, end="")
```

#### Async

```python
Expand Down
92 changes: 92 additions & 0 deletions ai21/clients/bedrock/_stream_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

import json
from functools import lru_cache
from typing import Iterator, AsyncIterator

import httpx
from botocore.eventstream import EventStreamMessage, EventStreamBuffer
from botocore.model import Shape
from botocore.parsers import EventStreamJSONParser

from ai21.errors import StreamingDecodeError
from ai21.stream.stream_commons import _SSEDecoderBase


_FINISH_REASON_NULL_STR = '"finish_reason":null'


@lru_cache(maxsize=None)
def get_response_stream_shape() -> Shape:
from botocore.model import ServiceModel
from botocore.loaders import Loader

loader = Loader()
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
bedrock_service_model = ServiceModel(bedrock_service_dict)
return bedrock_service_model.shape_for("ResponseStream")


class _AWSEventStreamDecoder(_SSEDecoderBase):
def __init__(self) -> None:
self._parser = EventStreamJSONParser()

def iter(self, response: httpx.Response) -> Iterator[str]:
event_stream_buffer = EventStreamBuffer()
previous_item = None
for chunk in response.iter_bytes():
try:
item = next(self._process_chunks(event_stream_buffer, chunk))
except StopIteration as e:
raise StreamingDecodeError(chunk=str(chunk), error_message=str(e))
# For Bedrock metering chunk:
if previous_item is not None:
item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item)
if _FINISH_REASON_NULL_STR not in item and previous_item is None:
previous_item = item
continue
yield item

async def aiter(self, response: httpx.Response) -> AsyncIterator[str]:
event_stream_buffer = EventStreamBuffer()
previous_item = None
async for chunk in response.aiter_bytes():
try:
item = next(self._process_chunks(event_stream_buffer, chunk))
except StopIteration as e:
raise StreamingDecodeError(chunk=str(chunk), error_message=str(e))
# For Bedrock metering chunk:
if previous_item is not None:
item = self._build_last_chunk(last_model_chunk=previous_item, bedrock_metrics_chunk=item)
if _FINISH_REASON_NULL_STR not in item and previous_item is None:
previous_item = item
continue
yield item

def _parse_message_from_event(self, event: EventStreamMessage) -> str | None:
response_dict = event.to_response_dict()
parsed_response = self._parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")

chunk = parsed_response.get("chunk")
if not chunk:
return None

return chunk.get("bytes").decode() # type: ignore[no-any-return]

def _build_last_chunk(self, last_model_chunk: str, bedrock_metrics_chunk: str) -> str:
chunk_dict = json.loads(last_model_chunk)
bedrock_metrics_dict = json.loads(bedrock_metrics_chunk)
chunk_dict = {**chunk_dict, **bedrock_metrics_dict}
return json.dumps(chunk_dict)

def _process_chunks(self, event_stream_buffer, chunk) -> Iterator[str]:
try:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield message
except Exception as e:
raise StreamingDecodeError(chunk=str(chunk), error_message=str(e))
14 changes: 11 additions & 3 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ai21 import AI21APIError
from ai21.ai21_env_config import AI21EnvConfig
from ai21.clients.aws.aws_authorization import AWSAuthorization
from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder
from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion
from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException
Expand Down Expand Up @@ -42,18 +43,19 @@ def _handle_bedrock_error(aws_error: AI21APIError) -> None:
class BaseBedrockClient:
def __init__(self, region: str, session: Optional[boto3.Session]):
self._aws_auth = AWSAuthorization(aws_session=session or boto3.Session(region_name=region))
self._streaming_decoder = _AWSEventStreamDecoder()

def _prepare_options(self, options: RequestOptions) -> RequestOptions:
body = options.body

model = body.pop("model", None)
stream = body.pop("stream", False)

endpoint = "invoke"
if stream:
_logger.warning("Field stream is not supported. Ignoring it.")
endpoint = "invoke-with-response-stream"

# When stream is supported we would need to update this section and the URL
url = f"{options.url}/model/{model}/invoke"
url = f"{options.url}/model/{model}/{endpoint}"
headers = self._prepare_headers(url=url, body=body)

return options.replace(
Expand Down Expand Up @@ -119,6 +121,9 @@ def _build_request(self, options: RequestOptions) -> httpx.Request:
def _prepare_url(self, options: RequestOptions) -> str:
return options.url

def _get_streaming_decoder(self) -> _AWSEventStreamDecoder:
return self._streaming_decoder


class AsyncAI21BedrockClient(AsyncAI21HTTPClient, BaseBedrockClient):
def __init__(
Expand Down Expand Up @@ -167,3 +172,6 @@ def _build_request(self, options: RequestOptions) -> httpx.Request:

def _prepare_url(self, options: RequestOptions) -> str:
return options.url

def _get_streaming_decoder(self) -> _AWSEventStreamDecoder:
return self._streaming_decoder
20 changes: 17 additions & 3 deletions ai21/clients/studio/resources/studio_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models._pydantic_compatibility import _from_dict
from ai21.stream.stream_commons import _SSEDecoderBase
from ai21.types import ResponseT, StreamT, AsyncStreamT
from ai21.utils.typing import extract_type

Expand All @@ -18,10 +19,11 @@ def _cast_response(
response_cls: Optional[ResponseT],
stream_cls: Optional[AsyncStreamT] = None,
stream: bool = False,
streaming_decoder: Optional[_SSEDecoderBase] = None,
) -> ResponseT | AsyncStreamT | None:
if stream and stream_cls is not None:
cast_to = extract_type(stream_cls)
return stream_cls(cast_to=cast_to, response=response)
return stream_cls(cast_to=cast_to, response=response, streaming_decoder=streaming_decoder)

if response_cls is None:
return None
Expand Down Expand Up @@ -64,7 +66,13 @@ def _post(
files=files,
)

return _cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls)
return _cast_response(
stream=stream,
response=response,
response_cls=response_cls,
stream_cls=stream_cls,
streaming_decoder=self._client._get_streaming_decoder(),
)

def _get(
self, path: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -109,7 +117,13 @@ async def _post(
files=files,
)

return _cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls)
return _cast_response(
stream=stream,
response=response,
response_cls=response_cls,
stream_cls=stream_cls,
streaming_decoder=self._client._get_streaming_decoder(),
)

async def _get(
self, path: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None
Expand Down
6 changes: 4 additions & 2 deletions ai21/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ def __init__(self, key: str):


class StreamingDecodeError(AI21Error):
def __init__(self, chunk: str):
message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format"
def __init__(self, chunk: str, error_message: Optional[str] = None):
message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format."
if error_message:
message = f"{message} Error: {error_message}"
super().__init__(message)


Expand Down
5 changes: 5 additions & 0 deletions ai21/http_client/async_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RETRY_BACK_OFF_FACTOR,
TIME_BETWEEN_RETRIES,
)
from ai21.stream.stream_commons import _SSEDecoder

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
retry=retry_if_result(self._should_retry),
stop=stop_after_attempt(self._num_retries),
)(self._request)
self._streaming_decoder = _SSEDecoder()

async def execute_http_request(
self,
Expand Down Expand Up @@ -119,3 +121,6 @@ def _init_client(self, client: Optional[httpx.AsyncClient]) -> httpx.AsyncClient
return httpx.AsyncClient(transport=_requests_retry_async_session(retries=self._num_retries))

return httpx.AsyncClient()

def _get_streaming_decoder(self) -> _SSEDecoder:
return self._streaming_decoder
5 changes: 5 additions & 0 deletions ai21/http_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from ai21.models.request_options import RequestOptions
from ai21.stream.stream import Stream
from ai21.stream.stream_commons import _SSEDecoder

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
retry=retry_if_result(self._should_retry),
stop=stop_after_attempt(self._num_retries),
)(self._request)
self._streaming_decoder = _SSEDecoder()

def execute_http_request(
self,
Expand Down Expand Up @@ -117,3 +119,6 @@ def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client:
return httpx.Client(transport=_requests_retry_session(retries=self._num_retries))

return httpx.Client()

def _get_streaming_decoder(self) -> _SSEDecoder:
return self._streaming_decoder
5 changes: 5 additions & 0 deletions ai21/models/ai21_base_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict
from typing_extensions import Self

Expand All @@ -11,11 +12,15 @@ class AI21BaseModel(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
protected_namespaces=(),
extra="allow",
)
else:

class Config:
from pydantic import Extra

allow_population_by_field_name = True
extra = Extra.allow

def to_dict(self, **kwargs) -> Dict[str, Any]:
warnings.warn(
Expand Down
7 changes: 4 additions & 3 deletions ai21/stream/async_stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Generic, AsyncIterator
from typing import Generic, AsyncIterator, Optional, Any
from typing_extensions import Self
from types import TracebackType

Expand All @@ -17,10 +17,11 @@ def __init__(
*,
cast_to: type[_T],
response: httpx.Response,
streaming_decoder: Optional[Any] = None,
):
self.response = response
self.cast_to = cast_to
self._decoder = _SSEDecoder()
self._decoder = streaming_decoder or _SSEDecoder()
self._iterator = self.__stream__()

async def __anext__(self) -> _T:
Expand All @@ -31,7 +32,7 @@ async def __aiter__(self) -> AsyncIterator[_T]:
yield item

async def __stream__(self) -> AsyncIterator[_T]:
iterator = self._decoder.aiter(self.response.aiter_lines())
iterator = self._decoder.aiter(self.response)
async for chunk in iterator:
if chunk.endswith(_SSE_DONE_MSG):
break
Expand Down
7 changes: 4 additions & 3 deletions ai21/stream/stream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from types import TracebackType
from typing import Generic, Iterator
from typing import Generic, Iterator, Optional, Any

import httpx
from typing_extensions import Self
Expand All @@ -17,10 +17,11 @@ def __init__(
*,
cast_to: type[_T],
response: httpx.Response,
streaming_decoder: Optional[Any] = None,
):
self.response = response
self.cast_to = cast_to
self._decoder = _SSEDecoder()
self._decoder = streaming_decoder or _SSEDecoder()
self._iterator = self.__stream__()

def __next__(self) -> _T:
Expand All @@ -31,7 +32,7 @@ def __iter__(self) -> Iterator[_T]:
yield item

def __stream__(self) -> Iterator[_T]:
iterator = self._decoder.iter(self.response.iter_lines())
iterator = self._decoder.iter(self.response)
for chunk in iterator:
if chunk.endswith(_SSE_DONE_MSG):
break
Expand Down
Loading