diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index 50ce63de..dd929ccc 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -29,6 +29,7 @@ def create( presence_penalty: Penalty | NotGiven = NOT_GIVEN, count_penalty: Penalty | NotGiven = NOT_GIVEN, epoch: int | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: """ @@ -46,6 +47,9 @@ def create( :param presence_penalty: A penalty applied to tokens that are already present in the prompt. :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses :param epoch: + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. :param kwargs: :return: """ @@ -70,6 +74,7 @@ def _create_body( presence_penalty: Penalty | NotGiven, count_penalty: Penalty | NotGiven, epoch: int | NotGiven, + logit_bias: Dict[str, float] | NotGiven, ): return remove_not_given( { @@ -87,5 +92,6 @@ def _create_body( "presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(), "countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(), "epoch": epoch, + "logitBias": logit_bias, } ) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 06a8166a..377cc4bd 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource from ai21.models import Penalty, CompletionsResponse @@ -21,8 +21,27 @@ def create( frequency_penalty: Penalty | NotGiven = NOT_GIVEN, presence_penalty: Penalty | NotGiven = NOT_GIVEN, count_penalty: Penalty | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: + """ + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. + :param kwargs: + :return: + """ body = remove_not_given( { "prompt": prompt, @@ -36,6 +55,7 @@ def create( "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, + "logitBias": logit_bias, } ) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 513b4d1a..7741130f 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List +from typing import List, Dict from ai21.clients.common.completion_base import Completion from ai21.clients.studio.resources.studio_resource import StudioResource @@ -26,6 +26,7 @@ def create( presence_penalty: Penalty | NotGiven = NOT_GIVEN, count_penalty: Penalty | NotGiven = NOT_GIVEN, epoch: int | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: url = f"{self._client.get_base_url()}/{model}" @@ -49,5 +50,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, epoch=epoch, + logit_bias=logit_bias, ) return self._json_to_response(self._post(url=url, body=body)) diff --git a/examples/studio/completion.py b/examples/studio/completion.py index b1d0715d..1f21483a 100644 --- a/examples/studio/completion.py +++ b/examples/studio/completion.py @@ -44,6 +44,7 @@ num_results=1, custom_model=None, epoch=1, + logit_bias={"▁I'm▁sorry": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py index 9938fa2a..2ff2841c 100644 --- a/tests/integration_tests/clients/studio/test_completion.py +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -1,5 +1,6 @@ import pytest +from typing import Dict from ai21 import AI21Client from ai21.models import Penalty @@ -25,6 +26,7 @@ def test_completion(): num_results=num_results, custom_model=None, epoch=1, + logit_bias={"▁a▁box▁of": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, @@ -110,3 +112,24 @@ def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( ) assert response.completions[0].finish_reason.reason == reason + + +@pytest.mark.parametrize( + ids=[ + "no_logit_bias", + "logit_bias_negative", + ], + argnames=["expected_result", "logit_bias"], + argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})], +) +def test_completion_logit_bias__should_impact_on_response(expected_result: str, logit_bias: Dict[str, float]): + client = AI21Client() + response = client.completion.create( + prompt="Life is like", + max_tokens=3, + model="j2-ultra", + temperature=0, + logit_bias=logit_bias, + ) + + assert response.completions[0].data.text.strip() == expected_result.strip()