From 4e39a775227d5004ca943fc437535e67d2691ba1 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 09:07:00 +0200 Subject: [PATCH 01/13] fix: add logit bias, fix studio completion --- ai21/clients/common/completion_base.py | 6 ++++++ ai21/clients/studio/resources/studio_completion.py | 4 +++- examples/studio/answer.py | 3 --- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index abc338f8..42e0b753 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -25,6 +25,7 @@ def create( presence_penalty: Optional[Penalty] = None, count_penalty: Optional[Penalty] = None, epoch: Optional[int] = None, + logit_bias: Optional[Dict[str, float]] = None, **kwargs, ) -> CompletionsResponse: """ @@ -42,6 +43,9 @@ def create( :param presence_penalty: A penalty applied to tokens that are already present in the prompt. :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses :param epoch: + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. :param kwargs: :return: """ @@ -66,6 +70,7 @@ def _create_body( presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty], epoch: Optional[int], + logit_bias: Optional[Dict[str, float]], ): return { "model": model, @@ -82,4 +87,5 @@ def _create_body( "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), "countPenalty": None if count_penalty is None else count_penalty.to_dict(), "epoch": epoch, + "logitBias": logit_bias, } diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 3b2cfc77..a4471167 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Dict from ai21.clients.common.completion_base import Completion from ai21.clients.studio.resources.studio_resource import StudioResource @@ -23,6 +23,7 @@ def create( presence_penalty: Optional[Penalty] = None, count_penalty: Optional[Penalty] = None, epoch: Optional[int] = None, + logit_bias: Optional[Dict[str, float]] = None, **kwargs, ) -> CompletionsResponse: url = f"{self._client.get_base_url()}/{model}" @@ -46,5 +47,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, epoch=epoch, + logit_bias=logit_bias, ) return self._json_to_response(self._post(url=url, body=body)) diff --git a/examples/studio/answer.py b/examples/studio/answer.py index 10659ddb..2d1a7c8a 100644 --- a/examples/studio/answer.py +++ b/examples/studio/answer.py @@ -1,5 +1,4 @@ from ai21 import AI21Client -from ai21.models import Mode, AnswerLength client = AI21Client() @@ -10,7 +9,5 @@ "ruled by the counts of Holland. 
By the 17th century, the province of Holland had risen to become a maritime and " "economic power, dominating the other provinces of the newly independent Dutch Republic.", question="When did Holland become an economic power?", - answer_length=AnswerLength.LONG, - mode=Mode.FLEXIBLE, ) print(response) From a6710b379c2659c917970c83388d9593d41835e6 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 15:28:47 +0200 Subject: [PATCH 02/13] fix: adjust tests --- tests/integration_tests/clients/studio/test_completion.py | 1 + tests/unittests/clients/studio/resources/conftest.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py index 9938fa2a..f760d477 100644 --- a/tests/integration_tests/clients/studio/test_completion.py +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -25,6 +25,7 @@ def test_completion(): num_results=num_results, custom_model=None, epoch=1, + logit_bias={"▁a▁box▁of": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 00e2088d..cbb6afe4 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -124,6 +124,7 @@ def get_studio_completion(): "maxTokens": None, "minTokens": 0, "epoch": None, + "logitBias": None, "numResults": 1, "topP": 1, "customModel": None, From 3effc4049cfce27081573d2d52363aee12496cea Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 16:57:17 +0200 Subject: [PATCH 03/13] fix: add logit bias integration test --- .../clients/studio/test_completion.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py index f760d477..2ff2841c 100644 --- a/tests/integration_tests/clients/studio/test_completion.py +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -1,5 +1,6 @@ import pytest +from typing import Dict from ai21 import AI21Client from ai21.models import Penalty @@ -111,3 +112,24 @@ def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( ) assert response.completions[0].finish_reason.reason == reason + + +@pytest.mark.parametrize( + ids=[ + "no_logit_bias", + "logit_bias_negative", + ], + argnames=["expected_result", "logit_bias"], + argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})], +) +def test_completion_logit_bias__should_impact_on_response(expected_result: str, logit_bias: Dict[str, float]): + client = AI21Client() + response = client.completion.create( + prompt="Life is like", + max_tokens=3, + model="j2-ultra", + temperature=0, + logit_bias=logit_bias, + ) + + assert response.completions[0].data.text.strip() == expected_result.strip() From 6c4db477ff4a76d4daa8c617105ed7856ce28ad3 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 17:41:16 +0200 Subject: [PATCH 04/13] fix: update studio completion example --- examples/studio/completion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/studio/completion.py b/examples/studio/completion.py index b1d0715d..22520d74 100644 --- a/examples/studio/completion.py +++ b/examples/studio/completion.py @@ -44,6 +44,7 @@ num_results=1, custom_model=None, epoch=1, + logit_bias={"_some_str_to_avoid", -100.0}, count_penalty=Penalty( 
scale=0, apply_to_emojis=False, From 0b2f87e75f04121a0895687ad662335ff998fa41 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 18:03:16 +0200 Subject: [PATCH 05/13] fix: fix studio completion example --- examples/studio/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/studio/completion.py b/examples/studio/completion.py index 22520d74..1f21483a 100644 --- a/examples/studio/completion.py +++ b/examples/studio/completion.py @@ -44,7 +44,7 @@ num_results=1, custom_model=None, epoch=1, - logit_bias={"_some_str_to_avoid", -100.0}, + logit_bias={"▁I'm▁sorry": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, From c3aabfc0df2e918995f9137a919dc3f5c5141a58 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Thu, 22 Feb 2024 09:33:28 +0200 Subject: [PATCH 06/13] fix: fix studio completion example --- ai21/clients/bedrock/resources/bedrock_completion.py | 4 +++- examples/bedrock/completion.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index e8617342..64dfd3de 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Dict from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource from ai21.models import Penalty, CompletionsResponse @@ -20,6 +20,7 @@ def create( frequency_penalty: Optional[Penalty] = None, presence_penalty: Optional[Penalty] = None, count_penalty: Optional[Penalty] = None, + logit_bias: Optional[Dict[str, float]] = None, **kwargs, ) -> CompletionsResponse: body = { @@ -34,6 +35,7 @@ def create( "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), "countPenalty": None if count_penalty is None else count_penalty.to_dict(), + "logitBias": logit_bias, } raw_response = self._invoke(model_id=model_id, body=body) diff --git a/examples/bedrock/completion.py b/examples/bedrock/completion.py index def48813..c283c34d 100644 --- a/examples/bedrock/completion.py +++ b/examples/bedrock/completion.py @@ -50,6 +50,7 @@ num_results=1, custom_model=None, epoch=1, + logit_bias={"▁I'm▁sorry": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, From 1bdb30c7a7b678bcc97f2d43ac30366f5c568c13 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Thu, 22 Feb 2024 17:23:12 +0200 Subject: [PATCH 07/13] fix: remove logit bias from bedrock --- ai21/clients/bedrock/resources/bedrock_completion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 5a6ad834..799ac09e 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict +from typing import Optional, List from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource from ai21.models import Penalty, CompletionsResponse @@ -20,7 +20,6 @@ def create( frequency_penalty: Optional[Penalty] = None, presence_penalty: Optional[Penalty] = None, count_penalty: Optional[Penalty] = None, - logit_bias: Optional[Dict[str, float]] = None, **kwargs, ) -> CompletionsResponse: body = { @@ -32,7 +31,6 @@ def create( "topP": top_p, 
"topKReturn": top_k_return, "stopSequences": stop_sequences or [], - "logitBias": logit_bias, } if frequency_penalty is not None: From ebc2c3781518c8579083c4e6782c60f2feacb113 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Thu, 22 Feb 2024 18:43:31 +0200 Subject: [PATCH 08/13] fix: add logit bias to sagemaker completion, add params string --- .../resources/sagemaker_completion.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 8da0fa8a..36e43a05 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Dict from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource from ai21.models import Penalty, CompletionsResponse @@ -19,8 +19,27 @@ def create( frequency_penalty: Optional[Penalty] = None, presence_penalty: Optional[Penalty] = None, count_penalty: Optional[Penalty] = None, + logit_bias: Optional[Dict[str, float]] = None, **kwargs, ) -> CompletionsResponse: + """ + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. 
+ :param kwargs: + :return: + """ body = { "prompt": prompt, "maxTokens": max_tokens, @@ -41,6 +60,9 @@ def create( if count_penalty is not None: body["countPenalty"] = count_penalty.to_dict() + if logit_bias is not None: + body["logitBias"] = logit_bias + raw_response = self._invoke(body) return CompletionsResponse.from_dict(raw_response) From 4e8b5944cf16fbbdfb19afb8ea12dc4c8a688b34 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 15:28:47 +0200 Subject: [PATCH 09/13] fix: adjust tests --- tests/integration_tests/clients/studio/test_completion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py index 9938fa2a..f760d477 100644 --- a/tests/integration_tests/clients/studio/test_completion.py +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -25,6 +25,7 @@ def test_completion(): num_results=num_results, custom_model=None, epoch=1, + logit_bias={"▁a▁box▁of": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, From 703bb658736ba1772d7bd19122870c01b5968e57 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 16:57:17 +0200 Subject: [PATCH 10/13] fix: add logit bias integration test --- .../clients/studio/test_completion.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py index f760d477..2ff2841c 100644 --- a/tests/integration_tests/clients/studio/test_completion.py +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -1,5 +1,6 @@ import pytest +from typing import Dict from ai21 import AI21Client from ai21.models import Penalty @@ -111,3 +112,24 @@ def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( ) assert response.completions[0].finish_reason.reason == reason + + +@pytest.mark.parametrize( + ids=[ + "no_logit_bias", + "logit_bias_negative", + ], + argnames=["expected_result", "logit_bias"], + argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})], +) +def test_completion_logit_bias__should_impact_on_response(expected_result: str, logit_bias: Dict[str, float]): + client = AI21Client() + response = client.completion.create( + prompt="Life is like", + max_tokens=3, + model="j2-ultra", + temperature=0, + logit_bias=logit_bias, + ) + + assert response.completions[0].data.text.strip() == expected_result.strip() From d79ae02ca318933290f2d1e4165b61ea9c200fcb Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 17:41:16 +0200 Subject: [PATCH 11/13] fix: update studio completion example --- examples/studio/completion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/studio/completion.py b/examples/studio/completion.py index b1d0715d..22520d74 100644 --- a/examples/studio/completion.py +++ b/examples/studio/completion.py @@ -44,6 +44,7 @@ num_results=1, custom_model=None, epoch=1, + logit_bias={"_some_str_to_avoid", -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, From 952b0043e53f9675afc08536465d434cd6748e94 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Wed, 21 Feb 2024 18:03:16 +0200 Subject: [PATCH 12/13] fix: fix studio completion example --- examples/studio/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/studio/completion.py b/examples/studio/completion.py index 22520d74..1f21483a 100644 --- a/examples/studio/completion.py +++ 
b/examples/studio/completion.py @@ -44,7 +44,7 @@ num_results=1, custom_model=None, epoch=1, - logit_bias={"_some_str_to_avoid", -100.0}, + logit_bias={"▁I'm▁sorry": -100.0}, count_penalty=Penalty( scale=0, apply_to_emojis=False, From b6589e1d1a7904d19cfb7f9479ba3931a8b71b2b Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Sun, 25 Feb 2024 16:29:15 +0200 Subject: [PATCH 13/13] fix: update code with new not_giving approach --- ai21/clients/common/completion_base.py | 6 +++++ .../resources/sagemaker_completion.py | 22 ++++++++++++++++++- .../studio/resources/studio_completion.py | 4 +++- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index 50ce63de..dd929ccc 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -29,6 +29,7 @@ def create( presence_penalty: Penalty | NotGiven = NOT_GIVEN, count_penalty: Penalty | NotGiven = NOT_GIVEN, epoch: int | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: """ @@ -46,6 +47,9 @@ def create( :param presence_penalty: A penalty applied to tokens that are already present in the prompt. :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses :param epoch: + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. :param kwargs: :return: """ @@ -70,6 +74,7 @@ def _create_body( presence_penalty: Penalty | NotGiven, count_penalty: Penalty | NotGiven, epoch: int | NotGiven, + logit_bias: Dict[str, float] | NotGiven, ): return remove_not_given( { @@ -87,5 +92,6 @@ def _create_body( "presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(), "countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(), "epoch": epoch, + "logitBias": logit_bias, } ) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 06a8166a..377cc4bd 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource from ai21.models import Penalty, CompletionsResponse @@ -21,8 +21,27 @@ def create( frequency_penalty: Penalty | NotGiven = NOT_GIVEN, presence_penalty: Penalty | NotGiven = NOT_GIVEN, count_penalty: Penalty | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: + """ + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. 
+ :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. + :param kwargs: + :return: + """ body = remove_not_given( { "prompt": prompt, @@ -36,6 +55,7 @@ def create( "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, + "logitBias": logit_bias, } ) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 513b4d1a..7741130f 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List +from typing import List, Dict from ai21.clients.common.completion_base import Completion from ai21.clients.studio.resources.studio_resource import StudioResource @@ -26,6 +26,7 @@ def create( presence_penalty: Penalty | NotGiven = NOT_GIVEN, count_penalty: Penalty | NotGiven = NOT_GIVEN, epoch: int | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: url = f"{self._client.get_base_url()}/{model}" @@ -49,5 +50,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, epoch=epoch, + logit_bias=logit_bias, ) return self._json_to_response(self._post(url=url, body=body))
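
Usage sketch for the feature introduced above. This is a minimal, illustrative example of calling the Studio completion endpoint with the new logit_bias parameter; it mirrors the integration test in tests/integration_tests/clients/studio/test_completion.py. The prompt, the bias value, and the assumption that the AI21Client picks up credentials from its default configuration (for example an API key in the environment) are illustrative choices, not part of the patches.

    # Minimal sketch of the logit_bias parameter added in this patch series.
    # Assumes the AI21 API key is available via the client's default
    # configuration; prompt and bias value are illustrative only.
    from ai21 import AI21Client

    client = AI21Client()

    # Keys are text representations of tokens ("▁" marks a leading space in the
    # J2 vocabulary, as in the examples and tests above). A strongly negative
    # bias steers generation away from the listed token; a positive bias
    # makes it more likely.
    response = client.completion.create(
        prompt="Life is like",
        model="j2-ultra",
        max_tokens=3,
        temperature=0,
        logit_bias={"▁a▁box▁of": -100.0},
    )

    print(response.completions[0].data.text)

With the bias omitted, the integration test expects the deterministic completion " a box of chocolates"; with the -100.0 bias on "▁a▁box▁of" it expects a different continuation (" riding a bicycle"), which is how the test verifies that the bias actually influences decoding.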