-
Notifications
You must be signed in to change notification settings - Fork 12
feat: support chat completion in studio SDK #84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
debb61a
feat: support chat completion in studio SDK
asafgardin ca5499f
refactor: Moved chat message to different packages
asafgardin 76d2fad
fix: Added __all__
asafgardin a59e503
refactor: imports and file structure
asafgardin 05e4f67
docs: Updated todo
asafgardin 4f4a291
fix: imports
asafgardin 21d3da8
fix: Added deprecation warning
asafgardin 4bf975e
test: Added a unittest
asafgardin 2e7b836
fix: CR
asafgardin 2a2e1d3
fix: CR
asafgardin ad2e205
fix: model name
asafgardin e7aac37
fix: alias
asafgardin 27911dd
revert: ruff
asafgardin debe0de
fix: circualr imports
asafgardin d200d0e
fix: all import
asafgardin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from .chat_completions import ChatCompletions as ChatCompletions |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import List, Optional, Union, Any, Dict | ||
|
|
||
| from ai21.clients.studio.resources.studio_resource import StudioResource | ||
| from ai21.models.chat import ChatMessage, ChatCompletionResponse | ||
| from ai21.types import NotGiven, NOT_GIVEN | ||
| from ai21.utils.typing import remove_not_given | ||
|
|
||
| __all__ = ["ChatCompletions"] | ||
|
|
||
|
|
||
| class ChatCompletions(StudioResource): | ||
| _module_name = "chat/complete" | ||
|
|
||
| def create( | ||
| self, | ||
| model: str, | ||
| messages: List[ChatMessage], | ||
| n: int | NotGiven = NOT_GIVEN, | ||
| logprobs: bool | NotGiven = NOT_GIVEN, | ||
| top_logprobs: int | NotGiven = NOT_GIVEN, | ||
| max_tokens: int | NotGiven = NOT_GIVEN, | ||
| temperature: float | NotGiven = NOT_GIVEN, | ||
| top_p: float | NotGiven = NOT_GIVEN, | ||
| stop: str | List[str] | NotGiven = NOT_GIVEN, | ||
| frequency_penalty: float | NotGiven = NOT_GIVEN, | ||
| presence_penalty: float | NotGiven = NOT_GIVEN, | ||
| **kwargs: Any, | ||
| ) -> ChatCompletionResponse: | ||
| body = self._create_body( | ||
| model=model, | ||
| messages=messages, | ||
| n=n, | ||
| logprobs=logprobs, | ||
| top_logprobs=top_logprobs, | ||
| stop=stop, | ||
| temperature=temperature, | ||
| max_tokens=max_tokens, | ||
| top_p=top_p, | ||
| frequency_penalty=frequency_penalty, | ||
| presence_penalty=presence_penalty, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| url = f"{self._client.get_base_url()}/{self._module_name}" | ||
| response = self._post(url=url, body=body) | ||
| return self._json_to_response(response) | ||
|
|
||
| def _create_body( | ||
| self, | ||
| model: str, | ||
| messages: List[ChatMessage], | ||
| n: Optional[int] | NotGiven, | ||
| logprobs: Optional[bool] | NotGiven, | ||
| top_logprobs: Optional[int] | NotGiven, | ||
| max_tokens: Optional[int] | NotGiven, | ||
| temperature: Optional[float] | NotGiven, | ||
| top_p: Optional[float] | NotGiven, | ||
| stop: Optional[Union[str, List[str]]] | NotGiven, | ||
| frequency_penalty: Optional[float] | NotGiven, | ||
| presence_penalty: Optional[float] | NotGiven, | ||
| **kwargs: Any, | ||
| ) -> Dict[str, Any]: | ||
| return remove_not_given( | ||
| { | ||
| "model": model, | ||
| "messages": [message.to_dict() for message in messages], | ||
| "temperature": temperature, | ||
| "maxTokens": max_tokens, | ||
| "n": n, | ||
| "topP": top_p, | ||
| "logprobs": logprobs, | ||
| "topLogprobs": top_logprobs, | ||
| "stop": stop, | ||
| "frequencyPenalty": frequency_penalty, | ||
| "presencePenalty": presence_penalty, | ||
| **kwargs, | ||
| } | ||
| ) | ||
|
|
||
| def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse: | ||
| return ChatCompletionResponse.from_dict(json) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from .chat_completion_response import ChatCompletionResponse | ||
| from .chat_completion_response import ChatCompletionResponseChoice | ||
| from .chat_message import ChatMessage | ||
| from .role_type import RoleType as RoleType | ||
|
|
||
| __all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice", "ChatMessage", "RoleType"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Optional, List | ||
|
|
||
| from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin | ||
| from ai21.models.logprobs import Logprobs | ||
| from ai21.models.usage_info import UsageInfo | ||
| from .chat_message import ChatMessage | ||
|
|
||
|
|
||
| @dataclass | ||
| class ChatCompletionResponseChoice(AI21BaseModelMixin): | ||
| index: int | ||
| message: ChatMessage | ||
| logprobs: Optional[Logprobs] = None | ||
| finish_reason: Optional[str] = None | ||
|
|
||
|
|
||
| @dataclass | ||
| class ChatCompletionResponse(AI21BaseModelMixin): | ||
| id: str | ||
| choices: List[ChatCompletionResponseChoice] | ||
| usage: UsageInfo |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin | ||
| from .role_type import RoleType | ||
|
|
||
|
|
||
| @dataclass | ||
| class ChatMessage(AI21BaseModelMixin): | ||
| role: RoleType | ||
| content: str |
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| from dataclasses import dataclass | ||
| from typing import List | ||
|
|
||
| from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin | ||
|
|
||
|
|
||
| @dataclass | ||
| class TopTokenData(AI21BaseModelMixin): | ||
| token: str | ||
| logprob: float | ||
|
|
||
|
|
||
| @dataclass | ||
| class LogprobsData(AI21BaseModelMixin): | ||
| token: str | ||
| logprob: float | ||
| top_logprobs: List[TopTokenData] | ||
|
|
||
|
|
||
| @dataclass | ||
| class Logprobs(AI21BaseModelMixin): | ||
| content: LogprobsData |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin | ||
|
|
||
|
|
||
| @dataclass | ||
| class UsageInfo(AI21BaseModelMixin): | ||
| prompt_tokens: int | ||
| completion_tokens: int | ||
| total_tokens: int |
Empty file.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| from ai21 import AI21Client | ||
| from ai21.models import RoleType | ||
| from ai21.models.chat import ChatMessage | ||
|
|
||
| system = "You're a support engineer in a SaaS company" | ||
| messages = [ | ||
| ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), | ||
| ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), | ||
| ChatMessage(content="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), | ||
| ] | ||
|
|
||
| client = AI21Client() | ||
|
|
||
| response = client.chat.completions.create( | ||
| messages=messages, | ||
| model="new-model-name", | ||
| n=2, | ||
| logprobs=True, | ||
| top_logprobs=2, | ||
| max_tokens=100, | ||
| temperature=0.7, | ||
| top_p=1.0, | ||
| stop=["\n"], | ||
| frequency_penalty=0.1, | ||
| presence_penalty=0.1, | ||
| ) | ||
|
|
||
| print(response) |
39 changes: 39 additions & 0 deletions
39
tests/integration_tests/clients/studio/test_chat_completions.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| import pytest | ||
|
|
||
| from ai21 import AI21Client | ||
| from ai21.models.chat import ChatMessage | ||
| from ai21.models import RoleType | ||
| from ai21.models.chat.chat_completion_response import ChatCompletionResponse | ||
|
|
||
|
|
||
| _MODEL = "new-model-name" | ||
| _MESSAGES = [ | ||
| ChatMessage( | ||
| content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", | ||
| role=RoleType.USER, | ||
| ), | ||
| ] | ||
|
|
||
|
|
||
| # TODO: When the api is officially released, update the test to assert the actual response | ||
| @pytest.mark.skip(reason="API is not officially released") | ||
| def test_chat_completion(): | ||
| num_results = 5 | ||
| messages = _MESSAGES | ||
|
|
||
| client = AI21Client() | ||
| response = client.chat.completions.create( | ||
| model=_MODEL, | ||
| messages=messages, | ||
| num_results=num_results, | ||
| max_tokens=64, | ||
| logprobs=True, | ||
| top_logprobs=0.6, | ||
| temperature=0.7, | ||
| stop=["\n"], | ||
| top_p=0.3, | ||
| frequency_penalty=0.2, | ||
| presence_penalty=0.4, | ||
| ) | ||
|
|
||
| assert isinstance(response, ChatCompletionResponse) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.