diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 69cec6a7..44dab3b5 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -21,6 +21,7 @@ def create( temperature: float | NotGiven = NOT_GIVEN, top_p: float | NotGiven = NOT_GIVEN, stop: str | List[str] | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> ChatCompletionResponse: body = self._create_body( @@ -30,6 +31,7 @@ def create( temperature=temperature, max_tokens=max_tokens, top_p=top_p, + n=n, **kwargs, ) @@ -45,6 +47,7 @@ def _create_body( temperature: Optional[float] | NotGiven, top_p: Optional[float] | NotGiven, stop: Optional[Union[str, List[str]]] | NotGiven, + n: Optional[int] | NotGiven, **kwargs: Any, ) -> Dict[str, Any]: return remove_not_given( @@ -55,6 +58,7 @@ def _create_body( "maxTokens": max_tokens, "topP": top_p, "stop": stop, + "n": n, **kwargs, } ) diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 9faf364b..684f0aa5 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -28,3 +28,25 @@ def test_chat_completion(): assert isinstance(response, ChatCompletionResponse) assert response.choices[0].message.content assert response.choices[0].message.role + + +def test_chat_completion__with_n_param__should_return_n_choices(): + messages = _MESSAGES + n = 3 + + client = AI21Client() + response = client.chat.completions.create( + model=_MODEL, + messages=messages, + max_tokens=64, + temperature=0.7, + stop=["\n"], + top_p=0.3, + n=n, + ) + + assert isinstance(response, ChatCompletionResponse) + assert len(response.choices) == n + for choice in response.choices: + assert choice.message.content + assert choice.message.role