Skip to content

Commit

Permalink
feat(streaming): add tools support
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed May 30, 2024
1 parent ad7adbd commit 9f00950
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 692 deletions.
2 changes: 1 addition & 1 deletion examples/tools_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


async def main() -> None:
async with client.beta.tools.messages.stream(
async with client.messages.stream(
max_tokens=1024,
model="claude-3-haiku-20240307",
tools=[
Expand Down
1 change: 1 addition & 0 deletions src/anthropic/lib/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._types import (
TextEvent as TextEvent,
InputJsonEvent as InputJsonEvent,
MessageStopEvent as MessageStopEvent,
MessageStreamEvent as MessageStreamEvent,
ContentBlockStopEvent as ContentBlockStopEvent,
Expand Down
93 changes: 86 additions & 7 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@

import asyncio
from types import TracebackType
from typing import TYPE_CHECKING, Generic, TypeVar, Callable
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, cast
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx

from ._types import TextEvent, MessageStopEvent, MessageStreamEvent, ContentBlockStopEvent
from ._types import (
TextEvent,
InputJsonEvent,
MessageStopEvent,
MessageStreamEvent,
ContentBlockStopEvent,
)
from ...types import Message, ContentBlock, RawMessageStreamEvent
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import construct_type
from ..._streaming import Stream, AsyncStream

if TYPE_CHECKING:
Expand Down Expand Up @@ -139,6 +146,18 @@ def on_text(self, text: str, snapshot: str) -> None:
```
"""

def on_input_json(self, delta: str, snapshot: object) -> None:
"""Callback that is fired whenever a `input_json_delta` ContentBlock is yielded.
The first argument is the json string delta and the second is the current accumulated
parsed object, for example:
```
on_input_json('{"locations": ["San ', {"locations": []})
on_input_json('Francisco"]', {"locations": ["San Francisco"]})
```
"""

def on_exception(self, exception: Exception) -> None:
"""Fires if any exception occurs"""

Expand Down Expand Up @@ -201,6 +220,15 @@ def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEve
snapshot=content_block.text,
)
)
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
self.on_input_json(event.delta.partial_json, content_block.input)
events_to_fire.append(
InputJsonEvent(
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
)
)
elif event.type == "content_block_stop":
content_block = self.current_message_snapshot.content[event.index]
self.on_content_block(content_block)
Expand Down Expand Up @@ -230,7 +258,9 @@ class MessageStreamManager(Generic[MessageStreamT]):
"""

def __init__(
self, api_request: Callable[[], Stream[RawMessageStreamEvent]], event_handler_cls: type[MessageStreamT]
self,
api_request: Callable[[], Stream[RawMessageStreamEvent]],
event_handler_cls: type[MessageStreamT],
) -> None:
self.__event_handler: MessageStreamT | None = None
self.__event_handler_cls: type[MessageStreamT] = event_handler_cls
Expand Down Expand Up @@ -382,6 +412,18 @@ async def on_text(self, text: str, snapshot: str) -> None:
```
"""

async def on_input_json(self, delta: str, snapshot: object) -> None:
"""Callback that is fired whenever a `input_json_delta` ContentBlock is yielded.
The first argument is the json string delta and the second is the current accumulated
parsed object, for example:
```
on_input_json('{"locations": ["San ', {"locations": []})
on_input_json('Francisco"]', {"locations": ["San Francisco"]})
```
"""

async def on_final_text(self, text: str) -> None:
"""Callback that is fired whenever a full `text` ContentBlock is accumulated.
Expand Down Expand Up @@ -450,10 +492,22 @@ async def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStr
snapshot=content_block.text,
)
)
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
await self.on_input_json(event.delta.partial_json, content_block.input)
events_to_fire.append(
InputJsonEvent(
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
)
)
elif event.type == "content_block_stop":
content_block = self.current_message_snapshot.content[event.index]
await self.on_content_block(content_block)

if content_block.type == "text":
await self.on_final_text(content_block.text)

events_to_fire.append(
ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block),
)
Expand Down Expand Up @@ -481,7 +535,9 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]):
"""

def __init__(
self, api_request: Awaitable[AsyncStream[RawMessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT]
self,
api_request: Awaitable[AsyncStream[RawMessageStreamEvent]],
event_handler_cls: type[AsyncMessageStreamT],
) -> None:
self.__event_handler: AsyncMessageStreamT | None = None
self.__event_handler_cls: type[AsyncMessageStreamT] = event_handler_cls
Expand All @@ -508,22 +564,45 @@ async def __aexit__(
await self.__event_handler.close()


def accumulate_event(*, event: RawMessageStreamEvent, current_snapshot: Message | None) -> Message:
JSON_BUF_PROPERTY = "__json_buf"


def accumulate_event(
*,
event: RawMessageStreamEvent,
current_snapshot: Message | None,
) -> Message:
if current_snapshot is None:
if event.type == "message_start":
return event.message
return Message.construct(**cast(Any, event.message.to_dict()))

raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"')

if event.type == "content_block_start":
# TODO: check index
current_snapshot.content.append(
ContentBlock.construct(**event.content_block.model_dump()),
cast(
ContentBlock,
construct_type(type_=ContentBlock, value=event.content_block.model_dump()),
),
)
elif event.type == "content_block_delta":
content = current_snapshot.content[event.index]
if content.type == "text" and event.delta.type == "text_delta":
content.text += event.delta.text
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
elif event.type == "message_delta":
current_snapshot.stop_reason = event.delta.stop_reason
current_snapshot.stop_sequence = event.delta.stop_sequence
Expand Down
18 changes: 18 additions & 0 deletions src/anthropic/lib/streaming/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@ class TextEvent(BaseModel):
"""The entire accumulated text"""


class InputJsonEvent(BaseModel):
type: Literal["input_json"]

partial_json: str
"""A partial JSON string delta
e.g. `'"San Francisco,'`
"""

snapshot: object
"""The currently accumulated parsed object.
e.g. `{'location': 'San Francisco, CA'}`
"""


class MessageStopEvent(RawMessageStopEvent):
type: Literal["message_stop"]

Expand All @@ -38,6 +55,7 @@ class ContentBlockStopEvent(RawContentBlockStopEvent):

MessageStreamEvent = Union[
TextEvent,
InputJsonEvent,
RawMessageStartEvent,
RawMessageDeltaEvent,
MessageStopEvent,
Expand Down
14 changes: 0 additions & 14 deletions src/anthropic/lib/streaming/beta/__init__.py

This file was deleted.

Loading

0 comments on commit 9f00950

Please sign in to comment.