From 00770e386c1d5c23fb9c492ed62971fac1908fdf Mon Sep 17 00:00:00 2001 From: etang Date: Wed, 10 Apr 2024 10:50:07 +0300 Subject: [PATCH 1/4] fix: revert completion message to dataclass --- ai21/clients/studio/resources/chat/chat_completions.py | 2 +- ai21/models/chat/chat_message.py | 10 +++++++--- .../clients/studio/test_chat_completions.py | 9 ++------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 503bc65b..69cec6a7 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -50,7 +50,7 @@ def _create_body( return remove_not_given( { "model": model, - "messages": messages, + "messages": [message.to_dict() for message in messages], "temperature": temperature, "maxTokens": max_tokens, "topP": top_p, diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py index e26ad1af..959c5f46 100644 --- a/ai21/models/chat/chat_message.py +++ b/ai21/models/chat/chat_message.py @@ -1,7 +1,11 @@ from __future__ import annotations -from typing_extensions import TypedDict, Literal +from dataclasses import dataclass -class ChatMessage(TypedDict): - role: str | Literal["user", "assistant", "system"] +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class ChatMessage(AI21BaseModelMixin): + role: str content: str diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 5aef4fff..c94e3a13 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -1,12 +1,10 @@ -import pytest - from ai21 import AI21Client from ai21.models.chat import ChatMessage from ai21.models import RoleType from ai21.models.chat.chat_completion_response import ChatCompletionResponse -_MODEL = "jamba-instruct-preview" +_MODEL = "jamba-instruct" _MESSAGES = [ ChatMessage( content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", @@ -15,17 +13,13 @@ ] -# TODO: When the api is officially released, update the test to assert the actual response -@pytest.mark.skip(reason="API is not officially released") def test_chat_completion(): - num_results = 5 messages = _MESSAGES client = AI21Client() response = client.chat.completions.create( model=_MODEL, messages=messages, - num_results=num_results, max_tokens=64, temperature=0.7, stop=["\n"], @@ -33,3 +27,4 @@ def test_chat_completion(): ) assert isinstance(response, ChatCompletionResponse) + assert response.choices[0] From b3bd8d263f93640e5d53d860bd8bceefd5785120 Mon Sep 17 00:00:00 2001 From: etang Date: Wed, 10 Apr 2024 10:58:52 +0300 Subject: [PATCH 2/4] fix: test was broken --- tests/unittests/clients/studio/resources/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 1bcbdabc..132f95ad 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -126,6 +126,7 @@ def get_chat_completions(): role="assistant", ), ] + _EXPECTED_SERIALIZED_MESSAGES = [message.to_dict() for message in _DUMMY_MESSAGES] return ( ChatCompletions, @@ -133,7 +134,7 @@ def get_chat_completions(): "chat/completions", { "model": _DUMMY_MODEL, - "messages": _DUMMY_MESSAGES, + "messages": _EXPECTED_SERIALIZED_MESSAGES, }, ChatCompletionResponse( id="some-id", From fa11c719e43419abc6a75a3560bd0d653e1bdb55 Mon Sep 17 00:00:00 2001 From: etang Date: Wed, 10 Apr 2024 11:03:09 +0300 Subject: [PATCH 3/4] fix: skip test until release --- .../clients/studio/test_chat_completions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index c94e3a13..1ad6e3f9 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -1,10 +1,12 @@ +import pytest + from ai21 import AI21Client from ai21.models.chat import ChatMessage from ai21.models import RoleType from ai21.models.chat.chat_completion_response import ChatCompletionResponse -_MODEL = "jamba-instruct" +_MODEL = "jamba-instruct-preview" _MESSAGES = [ ChatMessage( content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", @@ -13,6 +15,8 @@ ] +# TODO: When the api is officially released, update the test to assert the actual response +@pytest.mark.skip(reason="API is not officially released") def test_chat_completion(): messages = _MESSAGES From b3621e469956201c3e82cdbb8cfc6a8abb8e3edd Mon Sep 17 00:00:00 2001 From: etang Date: Wed, 10 Apr 2024 11:09:36 +0300 Subject: [PATCH 4/4] fix: skip test until release --- .../integration_tests/clients/studio/test_chat_completions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 1ad6e3f9..2b6275ad 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -31,4 +31,5 @@ def test_chat_completion(): ) assert isinstance(response, ChatCompletionResponse) - assert response.choices[0] + assert response.choices[0].message.content + assert response.choices[0].message.role