Skip to content
6 changes: 6 additions & 0 deletions ai21/clients/common/chat_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Any, Dict, Optional

from ai21.clients.studio.resources.chat import ChatCompletions
from ai21.models import Penalty, ChatResponse, ChatMessage


Expand Down Expand Up @@ -47,6 +48,11 @@ def create(
"""
pass

@property
@abstractmethod
def completions(self) -> ChatCompletions:
    """Chat-completions resource exposed by this chat client.

    Declared abstract, so concrete subclasses are expected to override this
    property and return a ``ChatCompletions`` instance.
    """

def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse:
    """Build a ``ChatResponse`` from the decoded JSON payload ``json``."""
    parsed_response = ChatResponse.from_dict(json)
    return parsed_response

Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
http_client=http_client,
)
self.completion = StudioCompletion(self._http_client)
self.chat = StudioChat(self._http_client)
self.chat: StudioChat = StudioChat(self._http_client)
self.summarize = StudioSummarize(self._http_client)
self.embed = StudioEmbed(self._http_client)
self.gec = StudioGEC(self._http_client)
Expand Down
3 changes: 3 additions & 0 deletions ai21/clients/studio/resources/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import annotations

from .chat_completions import ChatCompletions as ChatCompletions
83 changes: 83 additions & 0 deletions ai21/clients/studio/resources/chat/chat_completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

from typing import List, Optional, Union, Any, Dict

from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models.chat import ChatMessage, ChatCompletionResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given

__all__ = ["ChatCompletions"]


class ChatCompletions(StudioResource):
    """Studio resource that posts chat-completion requests to ``chat/complete``."""

    _module_name = "chat/complete"

    def create(
        self,
        model: str,
        messages: List[ChatMessage],
        n: int | NotGiven = NOT_GIVEN,
        logprobs: bool | NotGiven = NOT_GIVEN,
        top_logprobs: int | NotGiven = NOT_GIVEN,
        max_tokens: int | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        top_p: float | NotGiven = NOT_GIVEN,
        stop: str | List[str] | NotGiven = NOT_GIVEN,
        frequency_penalty: float | NotGiven = NOT_GIVEN,
        presence_penalty: float | NotGiven = NOT_GIVEN,
        **kwargs: Any,
    ) -> ChatCompletionResponse:
        """Send a chat-completion request and return the parsed response.

        The request body is assembled by ``_create_body`` and posted to
        ``<base url>/chat/complete``. Any extra ``kwargs`` are forwarded
        into the request body unchanged.

        :param model: Name of the model, sent as ``"model"``.
        :param messages: Conversation messages; each is serialized with
            its ``to_dict()`` method.
        :return: ``ChatCompletionResponse`` built from the JSON returned
            by the POST call.
        """
        payload = self._create_body(
            model=model,
            messages=messages,
            n=n,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            **kwargs,
        )

        raw_response = self._post(
            url=f"{self._client.get_base_url()}/{self._module_name}",
            body=payload,
        )
        return self._json_to_response(raw_response)

    def _create_body(
        self,
        model: str,
        messages: List[ChatMessage],
        n: Optional[int] | NotGiven,
        logprobs: Optional[bool] | NotGiven,
        top_logprobs: Optional[int] | NotGiven,
        max_tokens: Optional[int] | NotGiven,
        temperature: Optional[float] | NotGiven,
        top_p: Optional[float] | NotGiven,
        stop: Optional[Union[str, List[str]]] | NotGiven,
        frequency_penalty: Optional[float] | NotGiven,
        presence_penalty: Optional[float] | NotGiven,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Map the Python arguments onto the camelCase request-body keys.

        The assembled dict is passed through ``remove_not_given``
        (presumably dropping entries whose value is ``NOT_GIVEN`` --
        confirm against ``ai21.utils.typing``).
        """
        serialized_messages = [chat_message.to_dict() for chat_message in messages]
        raw_body: Dict[str, Any] = {
            "model": model,
            "messages": serialized_messages,
            "temperature": temperature,
            "maxTokens": max_tokens,
            "n": n,
            "topP": top_p,
            "logprobs": logprobs,
            "topLogprobs": top_logprobs,
            "stop": stop,
            "frequencyPenalty": frequency_penalty,
            "presencePenalty": presence_penalty,
            # Unpacked last, so an extra keyword that matches one of the
            # keys above replaces that entry.
            **kwargs,
        }
        return remove_not_given(raw_body)

    def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse:
        """Build a ``ChatCompletionResponse`` from the decoded JSON payload."""
        parsed_response = ChatCompletionResponse.from_dict(json)
        return parsed_response
5 changes: 5 additions & 0 deletions ai21/clients/studio/resources/studio_chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from ai21.clients.common.chat_base import Chat
from ai21.clients.studio.resources.chat import ChatCompletions
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models import ChatMessage, Penalty, ChatResponse

Expand Down Expand Up @@ -42,3 +43,7 @@ def create(
url = f"{self._client.get_base_url()}/{model}/{self._module_name}"
response = self._post(url=url, body=body)
return self._json_to_response(response)

@property
def completions(self) -> ChatCompletions:
    """Return a ``ChatCompletions`` resource bound to this chat's client.

    A new ``ChatCompletions`` object is constructed on every access.
    """
    completions_resource = ChatCompletions(self._client)
    return completions_resource
3 changes: 1 addition & 2 deletions ai21/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ai21.models.chat.role_type import RoleType
from ai21.models.chat_message import ChatMessage
from ai21.models.document_type import DocumentType
from ai21.models.embed_type import EmbedType
Expand Down Expand Up @@ -25,10 +26,8 @@
from ai21.models.responses.segmentation_response import SegmentationResponse
from ai21.models.responses.summarize_by_segment_response import SummarizeBySegmentResponse, SegmentSummary, Highlight
from ai21.models.responses.summarize_response import SummarizeResponse
from ai21.models.role_type import RoleType
from ai21.models.summary_method import SummaryMethod


__all__ = [
"ChatMessage",
"RoleType",
Expand Down
8 changes: 8 additions & 0 deletions ai21/models/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

from .chat_completion_response import ChatCompletionResponse
from .chat_completion_response import ChatCompletionResponseChoice
from .chat_message import ChatMessage
from .role_type import RoleType as RoleType

__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice", "ChatMessage", "RoleType"]
22 changes: 22 additions & 0 deletions ai21/models/chat/chat_completion_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass
from typing import Optional, List

from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin
from ai21.models.logprobs import Logprobs
from ai21.models.usage_info import UsageInfo
from .chat_message import ChatMessage


@dataclass
class ChatCompletionResponseChoice(AI21BaseModelMixin):
    """One entry of ``ChatCompletionResponse.choices``.

    ``logprobs`` and ``finish_reason`` are optional and default to ``None``.
    """

    # presumably the position of this choice within the response -- confirm
    index: int
    # Message carried by this choice.
    message: ChatMessage
    logprobs: Optional[Logprobs] = None
    finish_reason: Optional[str] = None


@dataclass
class ChatCompletionResponse(AI21BaseModelMixin):
    """Response model for a chat-completion request.

    All three fields are required (none has a default).
    """

    # Identifier string carried by the response.
    id: str
    # One ``ChatCompletionResponseChoice`` per returned choice.
    choices: List[ChatCompletionResponseChoice]
    # Token usage counts for the request.
    usage: UsageInfo
10 changes: 10 additions & 0 deletions ai21/models/chat/chat_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass

from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin
from .role_type import RoleType


@dataclass
class ChatMessage(AI21BaseModelMixin):
    """A single chat message: a ``role`` plus its text ``content``."""

    # Role of the message author, as a ``RoleType`` value.
    role: RoleType
    # Message text.
    content: str
File renamed without changes.
2 changes: 1 addition & 1 deletion ai21/models/chat_message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin
from ai21.models.role_type import RoleType
from ai21.models.chat.role_type import RoleType


@dataclass
Expand Down
22 changes: 22 additions & 0 deletions ai21/models/logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass
from typing import List

from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin


@dataclass
class TopTokenData(AI21BaseModelMixin):
    """A token paired with its log-probability.

    Used as the element type of ``LogprobsData.top_logprobs``.
    """

    token: str
    logprob: float


@dataclass
class LogprobsData(AI21BaseModelMixin):
    """Log-probability record for one token, with a list of ``TopTokenData``."""

    token: str
    logprob: float
    # presumably the highest-probability alternatives for this token -- confirm
    top_logprobs: List[TopTokenData]


@dataclass
class Logprobs(AI21BaseModelMixin):
    """Wrapper holding the log-probability data under ``content``."""

    # NOTE(review): typed as a single LogprobsData rather than a list --
    # confirm against the shape the API actually returns.
    content: LogprobsData
2 changes: 1 addition & 1 deletion ai21/models/responses/chat_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, List

from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin
from ai21.models.role_type import RoleType
from ai21.models.chat.role_type import RoleType


@dataclass
Expand Down
10 changes: 10 additions & 0 deletions ai21/models/usage_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass

from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin


@dataclass
class UsageInfo(AI21BaseModelMixin):
    """Integer token counts: prompt, completion and total."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
Empty file added examples/__init__.py
Empty file.
Empty file added examples/studio/__init__.py
Empty file.
11 changes: 9 additions & 2 deletions examples/studio/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""
This example uses the deprecated method `client.chat.create`;
use `client.chat.completions.create` instead.
"""

from ai21 import AI21Client
from ai21.models import ChatMessage, RoleType, Penalty
from ai21.models import RoleType, Penalty
from ai21.models import ChatMessage

system = "You're a support engineer in a SaaS company"
messages = [
Expand All @@ -8,11 +14,12 @@
ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER),
]


client = AI21Client()
response = client.chat.create(
system=system,
messages=messages,
model="j2-ultra",
model="j2-mid",
count_penalty=Penalty(
scale=0,
apply_to_emojis=False,
Expand Down
Empty file.
28 changes: 28 additions & 0 deletions examples/studio/chat/chat_completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from ai21 import AI21Client
from ai21.models import RoleType
from ai21.models.chat import ChatMessage

# NOTE(review): `system` is defined but never passed to the request in this
# script -- confirm whether it should be sent as a message or removed.
system = "You're a support engineer in a SaaS company"

# Conversation history sent to the model, oldest message first.
messages = [
    ChatMessage(
        role=RoleType.USER,
        content="Hello, I need help with a signup process.",
    ),
    ChatMessage(
        role=RoleType.ASSISTANT,
        content="Hi Alice, I can help you with that. What seems to be the problem?",
    ),
    ChatMessage(
        role=RoleType.USER,
        content="I am having trouble signing up for your product with my Google account.",
    ),
]

client = AI21Client()

# Request two completions (n=2) with log-probabilities enabled.
response = client.chat.completions.create(
    model="new-model-name",
    messages=messages,
    n=2,
    logprobs=True,
    top_logprobs=2,
    max_tokens=100,
    temperature=0.7,
    top_p=1.0,
    stop=["\n"],
    frequency_penalty=0.1,
    presence_penalty=0.1,
)

print(response)
39 changes: 39 additions & 0 deletions tests/integration_tests/clients/studio/test_chat_completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from ai21 import AI21Client
from ai21.models.chat import ChatMessage
from ai21.models import RoleType
from ai21.models.chat.chat_completion_response import ChatCompletionResponse


# Model name used by the integration test in this module.
_MODEL = "new-model-name"

# Single-turn conversation sent by the integration test in this module.
_MESSAGES = [
    ChatMessage(
        role=RoleType.USER,
        content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ",
    ),
]


# TODO: When the api is officially released, update the test to assert the actual response
@pytest.mark.skip(reason="API is not officially released")
def test_chat_completion():
    """Smoke-test ``client.chat.completions.create`` against the live API.

    Only checks that the call returns a ``ChatCompletionResponse``; the
    test is skipped until the API is released.
    """
    num_results = 5

    client = AI21Client()
    response = client.chat.completions.create(
        model=_MODEL,
        messages=_MESSAGES,
        # ``create`` names this parameter ``n``; passing ``num_results``
        # would forward it via **kwargs under the wrong body key.
        n=num_results,
        max_tokens=64,
        logprobs=True,
        # ``top_logprobs`` is annotated as an int in ``create``.
        top_logprobs=2,
        temperature=0.7,
        stop=["\n"],
        top_p=0.3,
        frequency_penalty=0.2,
        presence_penalty=0.4,
    )

    assert isinstance(response, ChatCompletionResponse)
1 change: 1 addition & 0 deletions tests/integration_tests/clients/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Run this script after setting the environment variable called AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
"""

import subprocess
from pathlib import Path

Expand Down
Loading