From 86b7ab0b9d176897b65cf127ee80c5ace18204a6 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 21 Feb 2024 16:01:25 +0200 Subject: [PATCH 1/6] fix: penalties in sagemaker --- .../sagemaker/resources/sagemaker_completion.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index d850eca4..8da0fa8a 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -30,10 +30,17 @@ def create( "topP": top_p, "topKReturn": top_k_return, "stopSequences": stop_sequences or [], - "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), - "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), - "countPenalty": None if count_penalty is None else count_penalty.to_dict(), } + + if frequency_penalty is not None: + body["frequencyPenalty"] = frequency_penalty.to_dict() + + if presence_penalty is not None: + body["presencePenalty"] = presence_penalty.to_dict() + + if count_penalty is not None: + body["countPenalty"] = count_penalty.to_dict() + raw_response = self._invoke(body) return CompletionsResponse.from_dict(raw_response) From 3567d948a55b623fd5a20c64bbf95137aab8920f Mon Sep 17 00:00:00 2001 From: etang Date: Wed, 21 Feb 2024 19:33:31 +0200 Subject: [PATCH 2/6] fix: don't pass None penalties to Bedrock --- .../clients/bedrock/resources/bedrock_completion.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index e8617342..799ac09e 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -31,10 +31,17 @@ def create( "topP": top_p, "topKReturn": top_k_return, "stopSequences": stop_sequences or [], - "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), - "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), - "countPenalty": None if count_penalty is None else count_penalty.to_dict(), } + + if frequency_penalty is not None: + body["frequencyPenalty"] = frequency_penalty.to_dict() + + if presence_penalty is not None: + body["presencePenalty"] = presence_penalty.to_dict() + + if count_penalty is not None: + body["countPenalty"] = count_penalty.to_dict() + raw_response = self._invoke(model_id=model_id, body=body) return CompletionsResponse.from_dict(raw_response) From 17dd43c76d41ff0e7be6446e468b1fe238171d2a Mon Sep 17 00:00:00 2001 From: etang Date: Wed, 21 Feb 2024 19:35:17 +0200 Subject: [PATCH 3/6] fix: remove some default arge, and some unused args from bedrock model --- examples/bedrock/completion.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/examples/bedrock/completion.py b/examples/bedrock/completion.py index def48813..59b39862 100644 --- a/examples/bedrock/completion.py +++ b/examples/bedrock/completion.py @@ -1,5 +1,4 @@ from ai21 import AI21BedrockClient, BedrockModelID -from ai21.models import Penalty # Bedrock is currently supported only in us-east-1 region. # Either set your profile's region to us-east-1 or uncomment next line @@ -46,34 +45,6 @@ temperature=0, top_p=1, top_k_return=0, - stop_sequences=["##"], - num_results=1, - custom_model=None, - epoch=1, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), ) print(response.completions[0].data.text) From 8436ea9011e49bcd37723a9c273b4b5cd7960669 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 21 Feb 2024 20:44:00 +0200 Subject: [PATCH 4/6] test: Added bedrock integration tests for penalties check --- .../clients/bedrock/__init__.py | 0 .../clients/bedrock/test_completion.py | 71 +++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 tests/integration_tests/clients/bedrock/__init__.py create mode 100644 tests/integration_tests/clients/bedrock/test_completion.py diff --git a/tests/integration_tests/clients/bedrock/__init__.py b/tests/integration_tests/clients/bedrock/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/clients/bedrock/test_completion.py b/tests/integration_tests/clients/bedrock/test_completion.py new file mode 100644 index 00000000..26ea85b3 --- /dev/null +++ b/tests/integration_tests/clients/bedrock/test_completion.py @@ -0,0 +1,71 @@ +from typing import Optional + +import pytest + +from ai21 import AI21BedrockClient +from ai21.clients.bedrock.bedrock_model_id import BedrockModelID +from ai21.models import Penalty +from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests + +_PROMPT = "Once upon a time, in a land far, far away, there was a" + + +@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.") +@pytest.mark.parametrize( + ids=[ + "when_no_penalties__should_return_response", + "when_penalties__should_return_response", + ], + argnames=["frequency_penalty", "presence_penalty", "count_penalty"], + argvalues=[ + (None, None, None), + ( + Penalty( + scale=0.5, + apply_to_emojis=True, + apply_to_numbers=True, + apply_to_stopwords=True, + apply_to_punctuation=True, + apply_to_whitespaces=True, + ), + Penalty( + scale=0.5, + apply_to_emojis=True, + apply_to_numbers=True, + apply_to_stopwords=True, + apply_to_punctuation=True, + apply_to_whitespaces=True, + ), + Penalty( + scale=0.5, + apply_to_emojis=True, + apply_to_numbers=True, + apply_to_stopwords=True, + apply_to_punctuation=True, + apply_to_whitespaces=True, + ), + ), + ], +) +def test_completion__when_no_penalties__should_return_response( + frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty] +): + client = AI21BedrockClient() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model_id=BedrockModelID.J2_MID_V1, + temperature=0, + top_p=1, + top_k_return=0, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + count_penalty=count_penalty, + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == 1 + # Check the results aren't all the same + assert len([completion.data.text for completion in response.completions]) == 1 + for completion in response.completions: + assert isinstance(completion.data.text, str) From f12349591d2db6200f9552d6bd3f7be0912f9eec Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 21 Feb 2024 20:44:38 +0200 Subject: [PATCH 5/6] ci: Integration tests on push --- .github/workflows/integration-tests.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 4b63cf92..1e9a8cea 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -1,10 +1,6 @@ name: Integration Tests -on: - push: - branches: - - main - - "rc_*" +on: [push] env: POETRY_VERSION: "1.7.1" From 69074107800fee4c7f4ccdeb50883fa50f8a5872 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 21 Feb 2024 20:47:05 +0200 Subject: [PATCH 6/6] fix: answer test --- examples/studio/answer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/studio/answer.py b/examples/studio/answer.py index 10659ddb..2d1a7c8a 100644 --- a/examples/studio/answer.py +++ b/examples/studio/answer.py @@ -1,5 +1,4 @@ from ai21 import AI21Client -from ai21.models import Mode, AnswerLength client = AI21Client() @@ -10,7 +9,5 @@ "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " "economic power, dominating the other provinces of the newly independent Dutch Republic.", question="When did Holland become an economic power?", - answer_length=AnswerLength.LONG, - mode=Mode.FLEXIBLE, ) print(response)