From 51ba3ed155ab96260a5a6f2e5d767f8176604dee Mon Sep 17 00:00:00 2001 From: etang Date: Thu, 22 Feb 2024 13:59:20 +0200 Subject: [PATCH 1/6] feat: support NOT_GIVEN type --- .../bedrock/resources/bedrock_completion.py | 58 ++++++------ ai21/clients/common/completion_base.py | 88 ++++++++++--------- .../resources/sagemaker_completion.py | 58 ++++++------ .../studio/resources/studio_completion.py | 31 ++++--- ai21/models/ai21_base_model_mixin.py | 3 + ai21/models/penalty.py | 14 +-- ai21/types.py | 30 +++++++ ai21/utils/__init__.py | 0 ai21/utils/typing.py | 22 +++++ poetry.lock | 24 ++--- pyproject.toml | 1 + tests/integration_tests/skip_helpers.py | 2 +- .../clients/studio/resources/conftest.py | 18 +--- .../studio/resources/test_studio_resources.py | 2 + tests/unittests/models/__init__.py | 0 tests/unittests/models/test_serialization.py | 6 ++ 16 files changed, 209 insertions(+), 148 deletions(-) create mode 100644 ai21/types.py create mode 100644 ai21/utils/__init__.py create mode 100644 ai21/utils/typing.py create mode 100644 tests/unittests/models/__init__.py create mode 100644 tests/unittests/models/test_serialization.py diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 11ce8f9d..7353a276 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,7 +1,9 @@ -from typing import Optional, List +from typing import List from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource from ai21.models import Penalty, CompletionsResponse +from ai21.types import NotGiven, NOT_GIVEN +from ai21.utils.typing import remove_not_given class BedrockCompletion(BedrockResource): @@ -9,37 +11,33 @@ def create( self, prompt: str, *, - max_tokens: Optional[int] = None, - num_results: Optional[int] = 1, - min_tokens: Optional[int] = 0, - temperature: Optional[float] = 0.7, - top_p: Optional[int] = 1, - top_k_return: Optional[int] = 0, - stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Penalty] = None, - presence_penalty: Optional[Penalty] = None, - count_penalty: Optional[Penalty] = None, + max_tokens: int | NotGiven = NOT_GIVEN, + num_results: int | NotGiven = NOT_GIVEN, + min_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + top_k_return: int | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: Penalty | NotGiven = NOT_GIVEN, + presence_penalty: Penalty | NotGiven = NOT_GIVEN, + count_penalty: Penalty | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: - body = { - "prompt": prompt, - "maxTokens": max_tokens, - "numResults": num_results, - "minTokens": min_tokens, - "temperature": temperature, - "topP": top_p, - "topKReturn": top_k_return, - "stopSequences": stop_sequences or [], - } - - if frequency_penalty is not None: - body["frequencyPenalty"] = frequency_penalty.to_dict() - - if presence_penalty is not None: - body["presencePenalty"] = presence_penalty.to_dict() - - if count_penalty is not None: - body["countPenalty"] = count_penalty.to_dict() + body = remove_not_given( + { + "prompt": prompt, + "maxTokens": max_tokens, + "numResults": num_results, + "minTokens": min_tokens, + "temperature": temperature, + "topP": top_p, + "topKReturn": top_k_return, + "stopSequences": stop_sequences or [], + "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, + "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, + "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, + } + ) model_id = kwargs.get("model_id", self._model_id) diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index abc338f8..50ce63de 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -1,7 +1,11 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Optional, List, Dict, Any +from typing import List, Dict, Any from ai21.models import Penalty, CompletionsResponse +from ai21.types import NOT_GIVEN, NotGiven +from ai21.utils.typing import remove_not_given class Completion(ABC): @@ -13,18 +17,18 @@ def create( model: str, prompt: str, *, - max_tokens: int = 64, - num_results: int = 1, - min_tokens=0, - temperature=0.7, - top_p=1, - top_k_return=0, - custom_model: Optional[str] = None, - stop_sequences: Optional[List[str]] = (), - frequency_penalty: Optional[Penalty] = None, - presence_penalty: Optional[Penalty] = None, - count_penalty: Optional[Penalty] = None, - epoch: Optional[int] = None, + max_tokens: int | NotGiven = NOT_GIVEN, + num_results: int | NotGiven = NOT_GIVEN, + min_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NOT_GIVEN = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + top_k_return: int | NotGiven = NOT_GIVEN, + custom_model: str | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: Penalty | NotGiven = NOT_GIVEN, + presence_penalty: Penalty | NotGiven = NOT_GIVEN, + count_penalty: Penalty | NotGiven = NOT_GIVEN, + epoch: int | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: """ @@ -54,32 +58,34 @@ def _create_body( self, model: str, prompt: str, - max_tokens: Optional[int], - num_results: Optional[int], - min_tokens: Optional[int], - temperature: Optional[float], - top_p: Optional[int], - top_k_return: Optional[int], - custom_model: Optional[str], - stop_sequences: Optional[List[str]], - frequency_penalty: Optional[Penalty], - presence_penalty: Optional[Penalty], - count_penalty: Optional[Penalty], - epoch: Optional[int], + max_tokens: int | NotGiven, + num_results: int | NotGiven, + min_tokens: int | NotGiven, + temperature: float | NotGiven, + top_p: float | NotGiven, + top_k_return: int | NotGiven, + custom_model: str | NotGiven, + stop_sequences: List[str] | NotGiven, + frequency_penalty: Penalty | NotGiven, + presence_penalty: Penalty | NotGiven, + count_penalty: Penalty | NotGiven, + epoch: int | NotGiven, ): - return { - "model": model, - "customModel": custom_model, - "prompt": prompt, - "maxTokens": max_tokens, - "numResults": num_results, - "minTokens": min_tokens, - "temperature": temperature, - "topP": top_p, - "topKReturn": top_k_return, - "stopSequences": stop_sequences or [], - "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), - "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), - "countPenalty": None if count_penalty is None else count_penalty.to_dict(), - "epoch": epoch, - } + return remove_not_given( + { + "model": model, + "customModel": custom_model, + "prompt": prompt, + "maxTokens": max_tokens, + "numResults": num_results, + "minTokens": min_tokens, + "temperature": temperature, + "topP": top_p, + "topKReturn": top_k_return, + "stopSequences": stop_sequences, + "frequencyPenalty": NOT_GIVEN if frequency_penalty is NOT_GIVEN else frequency_penalty.to_dict(), + "presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(), + "countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(), + "epoch": epoch, + } + ) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 8da0fa8a..06a8166a 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,7 +1,9 @@ -from typing import Optional, List +from typing import List from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource from ai21.models import Penalty, CompletionsResponse +from ai21.types import NotGiven, NOT_GIVEN +from ai21.utils.typing import remove_not_given class SageMakerCompletion(SageMakerResource): @@ -9,37 +11,33 @@ def create( self, prompt: str, *, - max_tokens: Optional[int] = None, - num_results: Optional[int] = 1, - min_tokens: Optional[int] = 0, - temperature: Optional[float] = 0.7, - top_p: Optional[int] = 1, - top_k_return: Optional[int] = 0, - stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Penalty] = None, - presence_penalty: Optional[Penalty] = None, - count_penalty: Optional[Penalty] = None, + max_tokens: int | NotGiven = NOT_GIVEN, + num_results: int | NotGiven = NOT_GIVEN, + min_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + top_k_return: int | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: Penalty | NotGiven = NOT_GIVEN, + presence_penalty: Penalty | NotGiven = NOT_GIVEN, + count_penalty: Penalty | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: - body = { - "prompt": prompt, - "maxTokens": max_tokens, - "numResults": num_results, - "minTokens": min_tokens, - "temperature": temperature, - "topP": top_p, - "topKReturn": top_k_return, - "stopSequences": stop_sequences or [], - } - - if frequency_penalty is not None: - body["frequencyPenalty"] = frequency_penalty.to_dict() - - if presence_penalty is not None: - body["presencePenalty"] = presence_penalty.to_dict() - - if count_penalty is not None: - body["countPenalty"] = count_penalty.to_dict() + body = remove_not_given( + { + "prompt": prompt, + "maxTokens": max_tokens, + "numResults": num_results, + "minTokens": min_tokens, + "temperature": temperature, + "topP": top_p, + "topKReturn": top_k_return, + "stopSequences": stop_sequences or [], + "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, + "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, + "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, + } + ) raw_response = self._invoke(body) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 3b2cfc77..513b4d1a 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,8 +1,11 @@ -from typing import Optional, List +from __future__ import annotations + +from typing import List from ai21.clients.common.completion_base import Completion from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import Penalty, CompletionsResponse +from ai21.types import NOT_GIVEN, NotGiven class StudioCompletion(StudioResource, Completion): @@ -11,23 +14,23 @@ def create( model: str, prompt: str, *, - max_tokens: Optional[int] = None, - num_results: Optional[int] = 1, - min_tokens: Optional[int] = 0, - temperature: Optional[float] = 0.7, - top_p: Optional[float] = 1, - top_k_return: Optional[int] = 0, - custom_model: Optional[str] = None, - stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Penalty] = None, - presence_penalty: Optional[Penalty] = None, - count_penalty: Optional[Penalty] = None, - epoch: Optional[int] = None, + max_tokens: int | NotGiven = NOT_GIVEN, + num_results: int | NotGiven = NOT_GIVEN, + min_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + top_k_return: int | NotGiven = NOT_GIVEN, + custom_model: str | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: Penalty | NotGiven = NOT_GIVEN, + presence_penalty: Penalty | NotGiven = NOT_GIVEN, + count_penalty: Penalty | NotGiven = NOT_GIVEN, + epoch: int | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: url = f"{self._client.get_base_url()}/{model}" - if custom_model is not None: + if custom_model: url = f"{url}/{custom_model}" url = f"{url}/{self._module_name}" diff --git a/ai21/models/ai21_base_model_mixin.py b/ai21/models/ai21_base_model_mixin.py index f3e7a2c9..af996fa5 100644 --- a/ai21/models/ai21_base_model_mixin.py +++ b/ai21/models/ai21_base_model_mixin.py @@ -1,7 +1,10 @@ from dataclasses_json import LetterCase, DataClassJsonMixin +from ai21.utils.typing import is_not_given + class AI21BaseModelMixin(DataClassJsonMixin): dataclass_json_config = { "letter_case": LetterCase.CAMEL, + "exclude": is_not_given, } diff --git a/ai21/models/penalty.py b/ai21/models/penalty.py index 74c69ff2..91aa4e03 100644 --- a/ai21/models/penalty.py +++ b/ai21/models/penalty.py @@ -1,14 +1,16 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import Optional +from ai21.types import NOT_GIVEN, NotGiven from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin @dataclass class Penalty(AI21BaseModelMixin): scale: float - apply_to_whitespaces: Optional[bool] = None - apply_to_punctuation: Optional[bool] = None - apply_to_numbers: Optional[bool] = None - apply_to_stopwords: Optional[bool] = None - apply_to_emojis: Optional[bool] = None + apply_to_whitespaces: bool | NotGiven = NOT_GIVEN + apply_to_punctuation: bool | NotGiven = NOT_GIVEN + apply_to_numbers: bool | NotGiven = NOT_GIVEN + apply_to_stopwords: bool | NotGiven = NOT_GIVEN + apply_to_emojis: bool | NotGiven = NOT_GIVEN diff --git a/ai21/types.py b/ai21/types.py new file mode 100644 index 00000000..137c938d --- /dev/null +++ b/ai21/types.py @@ -0,0 +1,30 @@ +from typing_extensions import Literal + + +# Sentinel class used until PEP 0661 is accepted +class NotGiven: + """ + A sentinel singleton class used to distinguish omitted keyword arguments + from those passed in with the value None (which may have different behavior). + + For example: + + ```py + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: + ... + + + get(timeout=1) # 1s timeout + get(timeout=None) # No timeout + get() # Default timeout behavior, which may not be statically known at the method definition. + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + def __repr__(self) -> str: + return "NOT_GIVEN" + + +NOT_GIVEN = NotGiven() diff --git a/ai21/utils/__init__.py b/ai21/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/utils/typing.py b/ai21/utils/typing.py new file mode 100644 index 00000000..ae77329b --- /dev/null +++ b/ai21/utils/typing.py @@ -0,0 +1,22 @@ +from typing import Any, Dict + +from ai21.types import NotGiven + + +def is_not_given(value: Any) -> bool: + return isinstance(value, NotGiven) + + +def remove_not_given(body: Dict[str, Any]) -> Dict[str, Any]: + return {k: v for k, v in body.items() if not is_not_given(v)} + + +def to_camel_case(snake_str: str) -> str: + return "".join(x.capitalize() for x in snake_str.lower().split("_")) + + +def to_lower_camel_case(snake_str: str) -> str: + # We capitalize the first letter of each component except the first one + # with the 'capitalize' method and join them together. + camel_string = to_camel_case(snake_str) + return snake_str[0].lower() + camel_string[1:] diff --git a/poetry.lock b/poetry.lock index d1f0c968..a98383f9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1009,24 +1009,24 @@ python-versions = ">=3.6" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d92f81886165cb14d7b067ef37e142256f1c6a90a65cd156b063a43da1708cfd"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b5edda50e5e9e15e54a6a8a0070302b00c518a9d32accc2346ad6c984aacd279"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:7048c338b6c86627afb27faecf418768acb6331fc24cfa56c93e8c9780f815fa"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, @@ -1034,7 +1034,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3fcc54cb0c8b811ff66082de1680b4b14cf8a81dce0d4fbf665c2265a81e07a1"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, @@ -1042,7 +1042,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:665f58bfd29b167039f714c6998178d27ccd83984084c286110ef26b230f259f"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, @@ -1050,7 +1050,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:9eb5dee2772b0f704ca2e45b1713e4e5198c18f515b52743576d196348f374d3"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, @@ -1251,13 +1251,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.7.1" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.9.0" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, - {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, + {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, + {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] [[package]] @@ -1329,4 +1329,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "39ea6a4fd93efce593b30be52de954f1d6ab4c2d39745a9541067a5af5f37a21" +content-hash = "b0a25fee53075252fd9221e0754103c67291afc63764ffd62f33d36d83453ce3" diff --git a/pyproject.toml b/pyproject.toml index 621fd510..d837d2a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ requests = "^2.31.0" ai21-tokenizer = "^0.3.9" boto3 = { version = "^1.28.82", optional = true } dataclasses-json = "^0.6.3" +typing-extensions = "^4.9.0" [tool.poetry.group.dev.dependencies] diff --git a/tests/integration_tests/skip_helpers.py b/tests/integration_tests/skip_helpers.py index 37b703ff..269262be 100644 --- a/tests/integration_tests/skip_helpers.py +++ b/tests/integration_tests/skip_helpers.py @@ -2,7 +2,7 @@ def should_skip_bedrock_integration_tests() -> bool: - return os.getenv("AWS_ACCESS_KEY_ID") is None and os.getenv("AWS_SECRET_ACCESS_KEY") is None + return os.getenv("AWS_ACCESS_KEY_ID") is None or os.getenv("AWS_SECRET_ACCESS_KEY") is None def should_skip_studio_integration_tests() -> bool: diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 00e2088d..5a7f3f2c 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -44,6 +44,7 @@ SegmentSummary, ) from ai21.models.responses.segmentation_response import Segment +from ai21.utils.typing import to_lower_camel_case @pytest.fixture @@ -109,29 +110,18 @@ def get_studio_chat(): ) -def get_studio_completion(): +def get_studio_completion(**kwargs): _DUMMY_MODEL = "dummy-completion-model" _DUMMY_PROMPT = "dummy-prompt" return ( StudioCompletion, - {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT}, + {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT, **kwargs}, f"{_DUMMY_MODEL}/complete", { "model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT, - "temperature": 0.7, - "maxTokens": None, - "minTokens": 0, - "epoch": None, - "numResults": 1, - "topP": 1, - "customModel": None, - "topKReturn": 0, - "stopSequences": [], - "frequencyPenalty": None, - "presencePenalty": None, - "countPenalty": None, + **{to_lower_camel_case(k): v for k, v in kwargs.items()}, }, CompletionsResponse( id="some-id", diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index eac6f274..96fdc154 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -32,6 +32,7 @@ class TestStudioResources: "studio_answer", "studio_chat", "studio_completion", + "studio_completion_with_extra_args", "studio_embed", "studio_gec", "studio_improvements", @@ -45,6 +46,7 @@ class TestStudioResources: (get_studio_answer()), (get_studio_chat()), (get_studio_completion()), + (get_studio_completion(temperature=0.5, max_tokens=50)), (get_studio_embed()), (get_studio_gec()), (get_studio_improvements()), diff --git a/tests/unittests/models/__init__.py b/tests/unittests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/models/test_serialization.py b/tests/unittests/models/test_serialization.py new file mode 100644 index 00000000..d854dd68 --- /dev/null +++ b/tests/unittests/models/test_serialization.py @@ -0,0 +1,6 @@ +from ai21.models import Penalty + + +def test_penalty__to_dict__when_has_not_given_fields__should_filter_them_out(): + penalty = Penalty(scale=0.5, apply_to_whitespaces=True) + assert penalty.to_dict() == {"scale": 0.5, "applyToWhitespaces": True} From 6eb124d6b8cbdcb452499aada9a545430b3055c8 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Fri, 23 Feb 2024 18:02:27 +0200 Subject: [PATCH 2/6] fix: test when penalty is not passed --- .../clients/bedrock/test_completion.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/integration_tests/clients/bedrock/test_completion.py b/tests/integration_tests/clients/bedrock/test_completion.py index d581241d..cfd7038c 100644 --- a/tests/integration_tests/clients/bedrock/test_completion.py +++ b/tests/integration_tests/clients/bedrock/test_completion.py @@ -50,17 +50,29 @@ def test_completion__when_no_penalties__should_return_response( frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty] ): - client = AI21BedrockClient(model_id=BedrockModelID.J2_MID_V1) - response = client.completion.create( - prompt=_PROMPT, - max_tokens=64, - temperature=0, - top_p=1, - top_k_return=0, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - count_penalty=count_penalty, - ) + client = AI21BedrockClient() + + if frequency_penalty is None and presence_penalty is None and count_penalty is None: + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model_id=BedrockModelID.J2_MID_V1, + temperature=0, + top_p=1, + top_k_return=0, + ) + else: + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model_id=BedrockModelID.J2_MID_V1, + temperature=0, + top_p=1, + top_k_return=0, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + count_penalty=count_penalty, + ) assert response.prompt.text == _PROMPT assert len(response.completions) == 1 From a44d67473e47ce1cbd6ff6b7cd11c4c9d986ed0f Mon Sep 17 00:00:00 2001 From: etang Date: Sat, 24 Feb 2024 19:19:16 +0200 Subject: [PATCH 3/6] fix: fix import, make more variants of penalty --- .../clients/bedrock/test_completion.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/integration_tests/clients/bedrock/test_completion.py b/tests/integration_tests/clients/bedrock/test_completion.py index cfd7038c..b7a04af2 100644 --- a/tests/integration_tests/clients/bedrock/test_completion.py +++ b/tests/integration_tests/clients/bedrock/test_completion.py @@ -20,21 +20,10 @@ argvalues=[ (None, None, None), ( + Penalty(scale=0.5), Penalty( scale=0.5, apply_to_emojis=True, - apply_to_numbers=True, - apply_to_stopwords=True, - apply_to_punctuation=True, - apply_to_whitespaces=True, - ), - Penalty( - scale=0.5, - apply_to_emojis=True, - apply_to_numbers=True, - apply_to_stopwords=True, - apply_to_punctuation=True, - apply_to_whitespaces=True, ), Penalty( scale=0.5, @@ -76,6 +65,7 @@ def test_completion__when_no_penalties__should_return_response( assert response.prompt.text == _PROMPT assert len(response.completions) == 1 + # Check the results aren't all the same assert len([completion.data.text for completion in response.completions]) == 1 for completion in response.completions: From 652ddc28b36cae640aca116d0045115a1cb81d73 Mon Sep 17 00:00:00 2001 From: etang Date: Sat, 24 Feb 2024 19:52:58 +0200 Subject: [PATCH 4/6] refactor: completion test refactor --- .../bedrock/resources/bedrock_completion.py | 2 + .../clients/bedrock/test_completion.py | 38 +++++++++---------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 7353a276..a5fbf5b2 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import List from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource diff --git a/tests/integration_tests/clients/bedrock/test_completion.py b/tests/integration_tests/clients/bedrock/test_completion.py index b7a04af2..e5dc3cdd 100644 --- a/tests/integration_tests/clients/bedrock/test_completion.py +++ b/tests/integration_tests/clients/bedrock/test_completion.py @@ -40,28 +40,24 @@ def test_completion__when_no_penalties__should_return_response( frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty] ): client = AI21BedrockClient() + completion_args = dict( + prompt=_PROMPT, + max_tokens=64, + model_id=BedrockModelID.J2_MID_V1, + temperature=0, + top_p=1, + top_k_return=0, + ) - if frequency_penalty is None and presence_penalty is None and count_penalty is None: - response = client.completion.create( - prompt=_PROMPT, - max_tokens=64, - model_id=BedrockModelID.J2_MID_V1, - temperature=0, - top_p=1, - top_k_return=0, - ) - else: - response = client.completion.create( - prompt=_PROMPT, - max_tokens=64, - model_id=BedrockModelID.J2_MID_V1, - temperature=0, - top_p=1, - top_k_return=0, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - count_penalty=count_penalty, - ) + for arg_name, penalty in [ + ("frequency_penalty", frequency_penalty), + ("presence_penalty", presence_penalty), + ("count_penalty", count_penalty), + ]: + if penalty: + completion_args[arg_name] = penalty + + response = client.completion.create(**completion_args) assert response.prompt.text == _PROMPT assert len(response.completions) == 1 From 4777b4696beba3feae8cce73bee04b1a450e4bc6 Mon Sep 17 00:00:00 2001 From: etang Date: Sun, 25 Feb 2024 12:06:08 +0200 Subject: [PATCH 5/6] fix: rename endpoints --- examples/sagemaker/completion.py | 2 +- examples/sagemaker/summarization.py | 2 +- tests/integration_tests/clients/test_sagemaker.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/sagemaker/completion.py b/examples/sagemaker/completion.py index 8bd6d254..2baa8634 100644 --- a/examples/sagemaker/completion.py +++ b/examples/sagemaker/completion.py @@ -31,7 +31,7 @@ "User: Hi, I have a question for you" ) -client = AI21SageMakerClient(endpoint_name="j2-quantization-mid-reach-dev-cve-version-12-202313") +client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") response = client.completion.create(prompt=prompt, max_tokens=2) print(response) diff --git a/examples/sagemaker/summarization.py b/examples/sagemaker/summarization.py index 4888180d..cea5620a 100644 --- a/examples/sagemaker/summarization.py +++ b/examples/sagemaker/summarization.py @@ -1,6 +1,6 @@ from ai21 import AI21SageMakerClient -client = AI21SageMakerClient(endpoint_name="j2-quantization-mid-reach-dev-cve-version-12-202313") +client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") response = client.summarize.create( source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2]" " From the 10th to the 16th century, " diff --git a/tests/integration_tests/clients/test_sagemaker.py b/tests/integration_tests/clients/test_sagemaker.py index 263760ef..631de64a 100644 --- a/tests/integration_tests/clients/test_sagemaker.py +++ b/tests/integration_tests/clients/test_sagemaker.py @@ -8,7 +8,7 @@ SAGEMAKER_PATH = Path(__file__).parent.parent.parent.parent / "examples" / SAGEMAKER_DIR -@pytest.mark.skip(reason="SageMaker integration tests need endpoints to be running") +# @pytest.mark.skip(reason="SageMaker integration tests need endpoints to be running") @pytest.mark.parametrize( argnames=["test_file_name"], argvalues=[ From ccc9b8ce271a0cff7ab7d0c38773649005834d00 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 25 Feb 2024 14:41:08 +0200 Subject: [PATCH 6/6] fix: uncomment skip --- tests/integration_tests/clients/test_sagemaker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/test_sagemaker.py b/tests/integration_tests/clients/test_sagemaker.py index 631de64a..263760ef 100644 --- a/tests/integration_tests/clients/test_sagemaker.py +++ b/tests/integration_tests/clients/test_sagemaker.py @@ -8,7 +8,7 @@ SAGEMAKER_PATH = Path(__file__).parent.parent.parent.parent / "examples" / SAGEMAKER_DIR -# @pytest.mark.skip(reason="SageMaker integration tests need endpoints to be running") +@pytest.mark.skip(reason="SageMaker integration tests need endpoints to be running") @pytest.mark.parametrize( argnames=["test_file_name"], argvalues=[