Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,30 @@ For a more detailed example, see the completion [examples](examples/studio/compl

---

## Streaming

We currently support streaming for the Chat Completions API in Jamba.

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

messages = [ChatMessage(content="What is the meaning of life?", role="user")]

client = AI21Client()

response = client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
stream=True,
)
for chunk in response:
print(chunk.choices[0].delta.content, end="")

```

---

## TSMs

AI21 Studio's Task-Specific Models offer a range of powerful tools. These models have been specifically designed for their respective tasks and provide high-quality results while optimizing efficiency.
Expand Down
7 changes: 5 additions & 2 deletions ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import platform
from typing import Optional, Dict, Any, BinaryIO

import httpx

from ai21.errors import MissingApiKeyError
from ai21.http_client import HttpClient
from ai21.version import VERSION
Expand Down Expand Up @@ -76,9 +78,10 @@ def execute_http_request(
method: str,
url: str,
params: Optional[Dict] = None,
stream: bool = False,
files: Optional[Dict[str, BinaryIO]] = None,
):
return self._http_client.execute_http_request(method=method, url=url, params=params, files=files)
) -> httpx.Response:
return self._http_client.execute_http_request(method=method, url=url, params=params, files=files, stream=stream)

def get_base_url(self) -> str:
return f"{self._api_host}/studio/{self._api_version}"
3 changes: 0 additions & 3 deletions ai21/clients/common/chat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def create(
def completions(self) -> ChatCompletions:
pass

def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse:
return ChatResponse.from_dict(json)

def _create_body(
self,
model: str,
Expand Down
5 changes: 1 addition & 4 deletions ai21/clients/common/completion_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Dict, Any
from typing import List, Dict

from ai21.models import Penalty, CompletionsResponse
from ai21.types import NOT_GIVEN, NotGiven
Expand Down Expand Up @@ -55,9 +55,6 @@ def create(
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse:
return CompletionsResponse.from_dict(json)

def _create_body(
self,
model: str,
Expand Down
3 changes: 0 additions & 3 deletions ai21/clients/common/custom_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ def list(self) -> List[CustomBaseModelResponse]:
def get(self, resource_id: str) -> CustomBaseModelResponse:
pass

def _json_to_response(self, json: Dict[str, Any]) -> CustomBaseModelResponse:
return CustomBaseModelResponse.from_dict(json)

def _create_body(
self,
dataset_id: str,
Expand Down
5 changes: 0 additions & 5 deletions ai21/clients/common/dataset_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional, Any, Dict

from ai21.models import DatasetResponse


class Dataset(ABC):
_module_name = "dataset"
Expand Down Expand Up @@ -40,9 +38,6 @@ def list(self):
def get(self, dataset_pid: str):
pass

def _json_to_response(self, json: Dict[str, Any]) -> DatasetResponse:
return DatasetResponse.from_dict(json)

def _create_body(
self,
dataset_name: str,
Expand Down
3 changes: 0 additions & 3 deletions ai21/clients/common/embed_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,5 @@ def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse:
return EmbedResponse.from_dict(json)

def _create_body(self, texts: List[str], type: Optional[str], **kwargs) -> Dict[str, Any]:
return {"texts": texts, "type": type, **kwargs}
3 changes: 0 additions & 3 deletions ai21/clients/common/improvements_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,5 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse:
return ImprovementsResponse.from_dict(json)

def _create_body(self, text: str, types: List[str], **kwargs) -> Dict[str, Any]:
return {"text": text, "types": types, **kwargs}
3 changes: 0 additions & 3 deletions ai21/clients/common/segmentation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,5 @@ def create(self, source: str, source_type: DocumentType, **kwargs) -> Segmentati
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse:
return SegmentationResponse.from_dict(json)

def _create_body(self, source: str, source_type: str, **kwargs) -> Dict[str, Any]:
return {"source": source, "sourceType": source_type, **kwargs}
3 changes: 0 additions & 3 deletions ai21/clients/common/summarize_by_segment_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ def create(
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse:
return SummarizeBySegmentResponse.from_dict(json)

def _create_body(
self,
source: str,
Expand Down
51 changes: 44 additions & 7 deletions ai21/clients/studio/resources/chat/chat_completions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from typing import List, Optional, Union, Any, Dict
from typing import List, Optional, Union, Any, Dict, Literal, overload

from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models.chat import ChatMessage, ChatCompletionResponse
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.stream import Stream
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given

Expand All @@ -14,6 +15,7 @@
class ChatCompletions(StudioResource):
_module_name = "chat/completions"

@overload
def create(
self,
model: str,
Expand All @@ -23,8 +25,38 @@ def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass

@overload
def create(
self,
model: str,
messages: List[ChatMessage],
stream: Literal[True],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> Stream[ChatCompletionChunk]:
pass

def create(
self,
model: str,
messages: List[ChatMessage],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | Stream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
raise ValueError(
"Please use the ChatMessage class from ai21.models.chat"
Expand All @@ -39,12 +71,18 @@ def create(
max_tokens=max_tokens,
top_p=top_p,
n=n,
stream=stream or False,
**kwargs,
)

url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body=body)
return self._json_to_response(response)
return self._post(
url=url,
body=body,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
response_cls=ChatCompletionResponse,
)

def _create_body(
self,
Expand All @@ -55,6 +93,7 @@ def _create_body(
top_p: Optional[float] | NotGiven,
stop: Optional[Union[str, List[str]]] | NotGiven,
n: Optional[int] | NotGiven,
stream: Literal[False] | Literal[True] | NotGiven,
**kwargs: Any,
) -> Dict[str, Any]:
return remove_not_given(
Expand All @@ -66,9 +105,7 @@ def _create_body(
"topP": top_p,
"stop": stop,
"n": n,
"stream": stream,
**kwargs,
}
)

def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse:
return ChatCompletionResponse.from_dict(json)
4 changes: 1 addition & 3 deletions ai21/clients/studio/resources/studio_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,4 @@ def create(

body = self._create_body(context=context, question=question, **kwargs)

response = self._post(url=url, body=body)

return self._json_to_response(response)
return self._post(url=url, body=body, response_cls=AnswerResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def create(
**kwargs,
)
url = f"{self._client.get_base_url()}/{model}/{self._module_name}"
response = self._post(url=url, body=body)
return self._json_to_response(response)
return self._post(url=url, body=body, response_cls=ChatResponse)

@property
def completions(self) -> ChatCompletions:
Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def create(
logit_bias=logit_bias,
**kwargs,
)
return self._json_to_response(self._post(url=url, body=body))
return self._post(url=url, body=body, response_cls=CompletionsResponse)
8 changes: 3 additions & 5 deletions ai21/clients/studio/resources/studio_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ def create(
num_epochs=num_epochs,
**kwargs,
)
self._post(url=url, body=body)
self._post(url=url, body=body, response_cls=None)

def list(self) -> List[CustomBaseModelResponse]:
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._get(url=url)

return [self._json_to_response(r) for r in response]
return self._get(url=url, response_cls=List[CustomBaseModelResponse])

def get(self, resource_id: str) -> CustomBaseModelResponse:
url = f"{self._client.get_base_url()}/{self._module_name}/{resource_id}"
return self._json_to_response(self._get(url=url))
return self._get(url=url, response_cls=CustomBaseModelResponse)
7 changes: 2 additions & 5 deletions ai21/clients/studio/resources/studio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,11 @@ def create(
)

def list(self) -> List[DatasetResponse]:
response = self._get(url=self._base_url())
return [self._json_to_response(r) for r in response]
return self._get(url=self._base_url(), response_cls=List[DatasetResponse])

def get(self, dataset_pid: str) -> DatasetResponse:
url = f"{self._base_url()}/{dataset_pid}"
response = self._get(url=url)

return self._json_to_response(response)
return self._get(url=url, response_cls=DatasetResponse)

def _base_url(self) -> str:
return f"{self._client.get_base_url()}/{self._module_name}"
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ class StudioEmbed(StudioResource, Embed):
def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(texts=texts, type=type, **kwargs)
response = self._post(url=url, body=body)

return self._json_to_response(response)
return self._post(url=url, body=body, response_cls=EmbedResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_gec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@ class StudioGEC(StudioResource, GEC):
def create(self, text: str, **kwargs) -> GECResponse:
body = self._create_body(text=text, **kwargs)
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body=body)

return self._json_to_response(response)
return self._post(url=url, body=body, response_cls=GECResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_improvements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,5 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme

url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(text=text, types=types, **kwargs)
response = self._post(url=url, body=body)

return self._json_to_response(response)
return self._post(url=url, body=body, response_cls=ImprovementsResponse)
Loading