From a7be77551de89b22e6db4917cbdbe2f3fc2868d1 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 09:36:16 +0200 Subject: [PATCH 01/36] test: get_tokenizer tests --- ai21/tokenizers/factory.py | 4 +-- tests/unittests/test_dummy.py | 2 -- tests/unittests/tokenizers/__init__.py | 0 .../tokenizers/test_ai21_tokenizer.py | 25 +++++++++++++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) delete mode 100644 tests/unittests/test_dummy.py create mode 100644 tests/unittests/tokenizers/__init__.py create mode 100644 tests/unittests/tokenizers/test_ai21_tokenizer.py diff --git a/ai21/tokenizers/factory.py b/ai21/tokenizers/factory.py index fd229231..cd728f77 100644 --- a/ai21/tokenizers/factory.py +++ b/ai21/tokenizers/factory.py @@ -16,6 +16,6 @@ def get_tokenizer() -> AI21Tokenizer: global _cached_tokenizer if _cached_tokenizer is None: - _cached_tokenizer = Tokenizer.get_tokenizer() + _cached_tokenizer = AI21Tokenizer(Tokenizer.get_tokenizer()) - return AI21Tokenizer(_cached_tokenizer) + return _cached_tokenizer diff --git a/tests/unittests/test_dummy.py b/tests/unittests/test_dummy.py deleted file mode 100644 index 39b433bc..00000000 --- a/tests/unittests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_assert(): - assert True diff --git a/tests/unittests/tokenizers/__init__.py b/tests/unittests/tokenizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/tokenizers/test_ai21_tokenizer.py b/tests/unittests/tokenizers/test_ai21_tokenizer.py new file mode 100644 index 00000000..33f89d9e --- /dev/null +++ b/tests/unittests/tokenizers/test_ai21_tokenizer.py @@ -0,0 +1,25 @@ +from ai21.tokenizers.factory import get_tokenizer + + +class TestAI21Tokenizer: + def test__count_tokens__should_return_number_of_tokens(self): + expected_number_of_tokens = 8 + tokenizer = get_tokenizer() + + actual_number_of_tokens = tokenizer.count_tokens("Text to Tokenize - Hello world!") + + assert actual_number_of_tokens == expected_number_of_tokens + + def test__tokenize__should_return_list_of_tokens(self): + expected_tokens = ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"] + tokenizer = get_tokenizer() + + actual_tokens = tokenizer.tokenize("Text to Tokenize - Hello world!") + + assert actual_tokens == expected_tokens + + def test__tokenizer__should_be_singleton__when_called_twice(self): + tokenizer1 = get_tokenizer() + tokenizer2 = get_tokenizer() + + assert tokenizer1 is tokenizer2 From 22fee638ff21be59c6fc6db45fe6844bc193507c Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 15:57:35 +0200 Subject: [PATCH 02/36] fix: cases --- ai21/clients/studio/resources/studio_chat.py | 2 +- ai21/clients/studio/resources/studio_completion.py | 2 +- ai21/resources/bases/chat_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f1dab12b..8fe1bca4 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._module_name}" + url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 84e179b3..48f85fa7 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -15,7 +15,7 @@ def create( num_results: Optional[int] = 1, min_tokens: Optional[int] = 0, temperature: Optional[float] = 0.7, - top_p: Optional[int] = 1, + top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, experimental_mode: bool = False, diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index f85270ee..e2a67c0d 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _module_name = "chat" + _MODULE_NAME = "chat" @abstractmethod def create( From 68fd37e9640e6bb1eab9db92e74142d437e58700 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 15:58:42 +0200 Subject: [PATCH 03/36] test: Added some unittests to resources --- ai21/services/sagemaker.py | 4 +- poetry.lock | 19 ++- pyproject.toml | 1 + tests/unittests/clients/__init__.py | 0 tests/unittests/clients/studio/__init__.py | 0 .../clients/studio/resources/__init__.py | 0 .../clients/studio/resources/conftest.py | 122 ++++++++++++++++++ .../studio/resources/test_studio_resources.py | 80 ++++++++++++ tests/unittests/services/__init__.py | 0 tests/unittests/services/test_sagemaker.py | 28 ++++ 10 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 tests/unittests/clients/__init__.py create mode 100644 tests/unittests/clients/studio/__init__.py create mode 100644 tests/unittests/clients/studio/resources/__init__.py create mode 100644 tests/unittests/clients/studio/resources/conftest.py create mode 100644 tests/unittests/clients/studio/resources/test_studio_resources.py create mode 100644 tests/unittests/services/__init__.py create mode 100644 tests/unittests/services/test_sagemaker.py diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index fa811541..2d9c6469 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -20,7 +20,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE client = AI21StudioClient() - response = client.execute_http_request( + client.execute_http_request( method="POST", url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", params={ @@ -30,7 +30,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = response["arn"] + arn = ["arn"] if not arn: raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) diff --git a/poetry.lock b/poetry.lock index 334d859e..cd17ee01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -848,6 +848,23 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-mock" +version = "3.12.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1312,4 +1329,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "71ce6369e72538e571ac954b5ebb4e66fa79c1752aa61af336144df577078cc4" +content-hash = "39ea6a4fd93efce593b30be52de954f1d6ab4c2d39745a9541067a5af5f37a21" diff --git a/pyproject.toml b/pyproject.toml index 120b38ae..060d165b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ safety = "*" ruff = "*" python-semantic-release = "^8.5.0" pytest = "^7.4.3" +pytest-mock = "^3.12.0" [tool.poetry.extras] AWS = ["boto3"] diff --git a/tests/unittests/clients/__init__.py b/tests/unittests/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/__init__.py b/tests/unittests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/__init__.py b/tests/unittests/clients/studio/resources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py new file mode 100644 index 00000000..c4b38e0f --- /dev/null +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -0,0 +1,122 @@ +import pytest +from pytest_mock import MockerFixture + +from ai21 import AnswerResponse, ChatResponse, CompletionsResponse +from ai21.ai21_studio_client import AI21StudioClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.clients.studio.resources.studio_chat import StudioChat +from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.resources.responses.chat_response import ChatOutput, FinishReason +from ai21.resources.responses.completion_response import Prompt, Completion, CompletionData, CompletionFinishReason + + +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture) -> AI21StudioClient: + return mocker.MagicMock(spec=AI21StudioClient) + + +def get_studio_answer(): + _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" + _DUMMY_QUESTION = "What is the answer?" + + return ( + StudioAnswer, + {"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION}, + "answer", + { + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + AnswerResponse(id="some-id", answer_in_context=True, answer="42"), + ) + + +def get_studio_chat(): + _DUMMY_MODEL = "dummy-chat-model" + _DUMMY_MESSAGES = [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "name": "Alice", + }, + { + "text": "Hi Alice, I can help you with that. What seems to be the problem?", + "role": "assistant", + "name": "Bob", + }, + ] + _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" + + return ( + StudioChat, + {"model": _DUMMY_MODEL, "messages": _DUMMY_MESSAGES, "system": _DUMMY_SYSTEM}, + f"{_DUMMY_MODEL}/chat", + { + "model": _DUMMY_MODEL, + "system": _DUMMY_SYSTEM, + "messages": _DUMMY_MESSAGES, + "temperature": 0.7, + "maxTokens": 300, + "minTokens": 0, + "numResults": 1, + "topP": 1.0, + "topKReturn": 0, + "stopSequences": None, + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + ChatResponse( + outputs=[ + ChatOutput( + text="Hello, I need help with a signup process.", + role="user", + finish_reason=FinishReason(reason="dummy_reason", length=1, sequence="1"), + ) + ] + ), + ) + + +def get_studio_completion(): + _DUMMY_MODEL = "dummy-completion-model" + _DUMMY_PROMPT = "dummy-prompt" + + return ( + StudioCompletion, + {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT}, + f"{_DUMMY_MODEL}/complete", + { + "model": _DUMMY_MODEL, + "prompt": _DUMMY_PROMPT, + "temperature": 0.7, + "maxTokens": None, + "minTokens": 0, + "epoch": None, + "numResults": 1, + "topP": 1, + "customModel": None, + "experimentalModel": False, + "topKReturn": 0, + "stopSequences": [], + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + CompletionsResponse( + id="some-id", + completions=[ + Completion( + data=CompletionData(text="dummy-completion", tokens=[]), + finish_reason=CompletionFinishReason(reason="dummy_reason", length=1), + ) + ], + prompt=Prompt(text="dummy-prompt"), + ), + ) + + +def get_studio_custom_model(): + pass diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py new file mode 100644 index 00000000..f9051c35 --- /dev/null +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -0,0 +1,80 @@ +from typing import TypeVar, Callable + +import pytest + +from ai21 import AnswerResponse +from ai21.ai21_studio_client import AI21StudioClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.resources.studio_resource import StudioResource +from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion + +_BASE_URL = "https://test.api.ai21.com/studio/v1" +_DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" +_DUMMY_QUESTION = "What is the answer?" + +T = TypeVar("T", bound=StudioResource) + + +class TestStudioResources: + @pytest.mark.parametrize( + ids=[ + "studio_answer", + "studio_chat", + "studio_completion", + ], + argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], + argvalues=[ + (get_studio_answer()), + (get_studio_chat()), + (get_studio_completion()), + ], + ) + def test__create__should_return_answer_response( + self, + studio_resource: Callable[[AI21StudioClient], T], + function_body, + url_suffix: str, + expected_body, + expected_response, + mock_ai21_studio_client: AI21StudioClient, + ): + mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + + resource = studio_resource(mock_ai21_studio_client) + + actual_response = resource.create( + **function_body, + ) + + assert actual_response == expected_response + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=f"{_BASE_URL}/{url_suffix}", + params=expected_body, + files=None, + ) + + def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21StudioClient): + expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") + mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + studio_answer = StudioAnswer(mock_ai21_studio_client) + + studio_answer.create( + context=_DUMMY_CONTEXT, + question=_DUMMY_QUESTION, + some_dummy_kwargs="some_dummy_value", + ) + + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=_BASE_URL + "/answer", + params={ + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + files=None, + ) diff --git a/tests/unittests/services/__init__.py b/tests/unittests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py new file mode 100644 index 00000000..8f2a4fad --- /dev/null +++ b/tests/unittests/services/test_sagemaker.py @@ -0,0 +1,28 @@ +import pytest +from pytest_mock import MockerFixture + +from ai21 import SageMaker +from ai21.ai21_studio_client import AI21StudioClient +from unittest.mock import patch + + +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture): + return mocker.patch.object( + AI21StudioClient, + "execute_http_request", + return_value={ + "arn": "some-model-package-id1", + "versions": ["1.0.0", "1.0.1"], + }, + ) + + +class TestSageMakerService: + def test__get_model_package_arn__should_return_model_package_arn(self, mocker, mock_ai21_studio_client): + with patch("ai21.ai21_studio_client.AI21StudioClient"): + expected_model_package_arn = "some-model-package-id1" + + actual_model_package_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == expected_model_package_arn From bdac3c4b3c6d6c57588b93b4485754603aae8161 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 16:34:25 +0200 Subject: [PATCH 04/36] fix: rename var --- ai21/http_client.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ai21/http_client.py b/ai21/http_client.py index 7f9b0286..9b75126c 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -56,10 +56,10 @@ def requests_retry_session(session, retries=0): class HttpClient: def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): - self.timeout_sec = timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC - self.num_retries = num_retries if num_retries is not None else DEFAULT_NUM_RETRIES - self.headers = headers if headers is not None else {} - self.apply_retry_policy = self.num_retries > 0 + self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC + self._num_retries = num_retries or DEFAULT_NUM_RETRIES + self._headers = headers or {} + self._apply_retry_policy = self._num_retries > 0 def execute_http_request( self, @@ -70,12 +70,12 @@ def execute_http_request( auth=None, ): session = ( - requests_retry_session(requests.Session(), retries=self.num_retries) - if self.apply_retry_policy + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy else requests.Session() ) - timeout = self.timeout_sec - headers = self.headers + timeout = self._timeout_sec + headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: @@ -112,7 +112,7 @@ def execute_http_request( raise connection_error except RetryError as retry_error: logger.error( - f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}" + f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}" ) raise retry_error except Exception as exception: From d1f8be104d8a5f85b2fb4a7a67d26391aa36f934 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 16:52:10 +0200 Subject: [PATCH 05/36] test: Added ai21 studio client tsts --- ai21/ai21_studio_client.py | 8 ++-- ai21/clients/studio/ai21_client.py | 28 +++++++------- tests/test_ai21_studio_client.py | 60 ++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 tests/test_ai21_studio_client.py diff --git a/ai21/ai21_studio_client.py b/ai21/ai21_studio_client.py index 0ceef023..6b6b5543 100644 --- a/ai21/ai21_studio_client.py +++ b/ai21/ai21_studio_client.py @@ -17,6 +17,7 @@ def __init__( timeout_sec: Optional[int] = None, num_retries: Optional[int] = None, organization: Optional[str] = None, + application: Optional[str] = None, via: Optional[str] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): @@ -32,12 +33,11 @@ def __init__( self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries self._organization = organization or self._env_config.organization - self._application = self._env_config.application + self._application = application or self._env_config.application self._via = via headers = self._build_headers(passed_headers=headers) - - self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -68,7 +68,7 @@ def _build_user_agent(self) -> str: return user_agent def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): - return self.http_client.execute_http_request(method=method, url=url, params=params, files=files) + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 025a3bab..7bca5cad 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -35,7 +35,7 @@ def __init__( via: Optional[str] = None, **kwargs, ): - studio_client = AI21StudioClient( + self._studio_client = AI21StudioClient( api_key=api_key, api_host=api_host, headers=headers, @@ -43,19 +43,19 @@ def __init__( num_retries=num_retries, via=via, ) - self.completion = StudioCompletion(studio_client) - self.chat = StudioChat(studio_client) - self.summarize = StudioSummarize(studio_client) - self.embed = StudioEmbed(studio_client) - self.gec = StudioGEC(studio_client) - self.improvements = StudioImprovements(studio_client) - self.paraphrase = StudioParaphrase(studio_client) - self.summarize_by_segment = StudioSummarizeBySegment(studio_client) - self.custom_model = StudioCustomModel(studio_client) - self.dataset = StudioDataset(studio_client) - self.answer = StudioAnswer(studio_client) - self.library = StudioLibrary(studio_client) - self.segmentation = StudioSegmentation(studio_client) + self.completion = StudioCompletion(self._studio_client) + self.chat = StudioChat(self._studio_client) + self.summarize = StudioSummarize(self._studio_client) + self.embed = StudioEmbed(self._studio_client) + self.gec = StudioGEC(self._studio_client) + self.improvements = StudioImprovements(self._studio_client) + self.paraphrase = StudioParaphrase(self._studio_client) + self.summarize_by_segment = StudioSummarizeBySegment(self._studio_client) + self.custom_model = StudioCustomModel(self._studio_client) + self.dataset = StudioDataset(self._studio_client) + self.answer = StudioAnswer(self._studio_client) + self.library = StudioLibrary(self._studio_clienself._studio_clientt) + self.segmentation = StudioSegmentation() def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_studio_client.py new file mode 100644 index 00000000..aaecfc99 --- /dev/null +++ b/tests/test_ai21_studio_client.py @@ -0,0 +1,60 @@ +from typing import Optional + +import pytest + +from ai21.ai21_studio_client import AI21StudioClient +from ai21.version import VERSION + +_DUMMY_API_KEY = "dummy_key" + + +class TestAI21StudioClient: + @pytest.mark.parametrize( + ids=[ + "when_pass_only_via__should_include_via_in_user_agent", + "when_pass_only_application__should_include_application_in_user_agent", + "when_pass_organization__should_include_organization_in_user_agent", + "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", + ], + argnames=["via", "application", "organization", "expected_user_agent"], + argvalues=[ + ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), + (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), + (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), + ( + "langchain", + "studio", + "ai21", + f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", + ), + ], + ) + def test__build_headers__user_agent( + self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str + ): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + assert client._http_client._headers["User-Agent"] == expected_user_agent + + def test__build_headers__authorization(self): + client = AI21StudioClient(api_key=_DUMMY_API_KEY) + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + def test__build_headers__when_pass_headers__should_append(self): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + assert client._http_client._headers["foo"] == "bar" + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + @pytest.mark.parametrize( + ids=[ + "when_api_host_is_not_set__should_return_default", + "when_api_host_is_set__should_return_set_value", + ], + argnames=["api_host", "expected_api_host"], + argvalues=[ + (None, "https://api.ai21.com/studio/v1"), + ("http://test_host", "http://test_host/studio/v1"), + ], + ) + def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + assert client.get_base_url() == expected_api_host From 287b4ebe996a4b7fc7a76455e58665a39eca8574 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 17:07:46 +0200 Subject: [PATCH 06/36] fix: rename files --- ...1_studio_client.py => ai21_http_client.py} | 2 +- ai21/clients/studio/ai21_client.py | 30 +++++++++---------- .../studio/resources/studio_library.py | 4 +-- ai21/resources/studio_resource.py | 4 +-- ai21/services/sagemaker.py | 6 ++-- tests/test_ai21_studio_client.py | 10 +++---- .../clients/studio/resources/conftest.py | 6 ++-- .../studio/resources/test_studio_resources.py | 8 ++--- tests/unittests/services/test_sagemaker.py | 4 +-- 9 files changed, 37 insertions(+), 37 deletions(-) rename ai21/{ai21_studio_client.py => ai21_http_client.py} (99%) diff --git a/ai21/ai21_studio_client.py b/ai21/ai21_http_client.py similarity index 99% rename from ai21/ai21_studio_client.py rename to ai21/ai21_http_client.py index 6b6b5543..df845f39 100644 --- a/ai21/ai21_studio_client.py +++ b/ai21/ai21_http_client.py @@ -6,7 +6,7 @@ from ai21.version import VERSION -class AI21StudioClient: +class AI21HTTPClient: def __init__( self, *, diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 7bca5cad..e308bc0a 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,6 +1,6 @@ from typing import Optional, Any, Dict -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -35,7 +35,7 @@ def __init__( via: Optional[str] = None, **kwargs, ): - self._studio_client = AI21StudioClient( + self._http_client = AI21HTTPClient( api_key=api_key, api_host=api_host, headers=headers, @@ -43,19 +43,19 @@ def __init__( num_retries=num_retries, via=via, ) - self.completion = StudioCompletion(self._studio_client) - self.chat = StudioChat(self._studio_client) - self.summarize = StudioSummarize(self._studio_client) - self.embed = StudioEmbed(self._studio_client) - self.gec = StudioGEC(self._studio_client) - self.improvements = StudioImprovements(self._studio_client) - self.paraphrase = StudioParaphrase(self._studio_client) - self.summarize_by_segment = StudioSummarizeBySegment(self._studio_client) - self.custom_model = StudioCustomModel(self._studio_client) - self.dataset = StudioDataset(self._studio_client) - self.answer = StudioAnswer(self._studio_client) - self.library = StudioLibrary(self._studio_clienself._studio_clientt) - self.segmentation = StudioSegmentation() + self.completion = StudioCompletion(self._http_client) + self.chat = StudioChat(self._http_client) + self.summarize = StudioSummarize(self._http_client) + self.embed = StudioEmbed(self._http_client) + self.gec = StudioGEC(self._http_client) + self.improvements = StudioImprovements(self._http_client) + self.paraphrase = StudioParaphrase(self._http_client) + self.summarize_by_segment = StudioSummarizeBySegment(self._http_client) + self.custom_model = StudioCustomModel(self._http_client) + self.dataset = StudioDataset(self._http_client) + self.answer = StudioAnswer(self._http_client) + self.library = StudioLibrary(self._http_client) + self.segmentation = StudioSegmentation(self._http_client) def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index ae785a85..42daedbb 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,6 @@ from typing import Optional, List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.resources.responses.file_response import FileResponse from ai21.resources.responses.library_answer_response import LibraryAnswerResponse from ai21.resources.responses.library_search_response import LibrarySearchResponse @@ -10,7 +10,7 @@ class StudioLibrary(StudioResource): _module_name = "library/files" - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): super().__init__(client) self.files = LibraryFiles(client) self.search = LibrarySearch(client) diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index a467bd94..ee23341c 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -3,11 +3,11 @@ from abc import ABC from typing import Any, Dict, Optional -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient class StudioResource(ABC): - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): self._client = client def _post( diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index 2d9c6469..a4c26d3e 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -1,6 +1,6 @@ from typing import List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) @@ -18,7 +18,7 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + client = AI21HTTPClient() client.execute_http_request( method="POST", @@ -40,7 +40,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + client = AI21HTTPClient() response = client.execute_http_request( method="POST", diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_studio_client.py index aaecfc99..bbbe2ec6 100644 --- a/tests/test_ai21_studio_client.py +++ b/tests/test_ai21_studio_client.py @@ -2,7 +2,7 @@ import pytest -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" @@ -32,15 +32,15 @@ class TestAI21StudioClient: def test__build_headers__user_agent( self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str ): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) assert client._http_client._headers["User-Agent"] == expected_user_agent def test__build_headers__authorization(self): - client = AI21StudioClient(api_key=_DUMMY_API_KEY) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY) assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" def test__build_headers__when_pass_headers__should_append(self): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) assert client._http_client._headers["foo"] == "bar" assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" @@ -56,5 +56,5 @@ def test__build_headers__when_pass_headers__should_append(self): ], ) def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index c4b38e0f..1a921fef 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -2,7 +2,7 @@ from pytest_mock import MockerFixture from ai21 import AnswerResponse, ChatResponse, CompletionsResponse -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -11,8 +11,8 @@ @pytest.fixture -def mock_ai21_studio_client(mocker: MockerFixture) -> AI21StudioClient: - return mocker.MagicMock(spec=AI21StudioClient) +def mock_ai21_studio_client(mocker: MockerFixture) -> AI21HTTPClient: + return mocker.MagicMock(spec=AI21HTTPClient) def get_studio_answer(): diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index f9051c35..0e4de3af 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -3,7 +3,7 @@ import pytest from ai21 import AnswerResponse -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.resources.studio_resource import StudioResource from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion @@ -31,12 +31,12 @@ class TestStudioResources: ) def test__create__should_return_answer_response( self, - studio_resource: Callable[[AI21StudioClient], T], + studio_resource: Callable[[AI21HTTPClient], T], function_body, url_suffix: str, expected_body, expected_response, - mock_ai21_studio_client: AI21StudioClient, + mock_ai21_studio_client: AI21HTTPClient, ): mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() mock_ai21_studio_client.get_base_url.return_value = _BASE_URL @@ -55,7 +55,7 @@ def test__create__should_return_answer_response( files=None, ) - def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21StudioClient): + def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21HTTPClient): expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() mock_ai21_studio_client.get_base_url.return_value = _BASE_URL diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index 8f2a4fad..531fb5ef 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -2,14 +2,14 @@ from pytest_mock import MockerFixture from ai21 import SageMaker -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from unittest.mock import patch @pytest.fixture def mock_ai21_studio_client(mocker: MockerFixture): return mocker.patch.object( - AI21StudioClient, + AI21HTTPClient, "execute_http_request", return_value={ "arn": "some-model-package-id1", From 8697aeeed1b6ecdcbdc4062e0cbb07f2aa3a7e6b Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 11:55:52 +0200 Subject: [PATCH 07/36] fix: Added types --- ai21/ai21_http_client.py | 12 ++++++++++-- ai21/clients/studio/ai21_client.py | 3 +++ ai21/http_client.py | 7 ++++--- ai21/resources/studio_resource.py | 3 ++- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index df845f39..b58fbb9f 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,3 +1,4 @@ +import io from typing import Optional, Dict, Any from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig @@ -19,6 +20,7 @@ def __init__( organization: Optional[str] = None, application: Optional[str] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): self._env_config = env_config @@ -37,7 +39,7 @@ def __init__( self._via = via headers = self._build_headers(passed_headers=headers) - self._http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = http_client or HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -67,7 +69,13 @@ def _build_user_agent(self) -> str: return user_agent - def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): + def execute_http_request( + self, + method: str, + url: str, + params: Optional[Dict] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, + ): return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index e308bc0a..5f6cee06 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -14,6 +14,7 @@ from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation from ai21.clients.studio.resources.studio_summarize import StudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment +from ai21.http_client import HttpClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer @@ -33,6 +34,7 @@ def __init__( timeout_sec: Optional[float] = None, num_retries: Optional[int] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, **kwargs, ): self._http_client = AI21HTTPClient( @@ -42,6 +44,7 @@ def __init__( timeout_sec=timeout_sec, num_retries=num_retries, via=via, + http_client=http_client, ) self.completion = StudioCompletion(self._http_client) self.chat = StudioChat(self._http_client) diff --git a/ai21/http_client.py b/ai21/http_client.py index 9b75126c..e5d9620a 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,5 +1,6 @@ +import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Callable, Tuple, Union import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -66,8 +67,8 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files=None, - auth=None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, + auth: Optional[Union[Tuple, Callable]] = None, ): session = ( requests_retry_session(requests.Session(), retries=self._num_retries) diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index ee23341c..7752be91 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from abc import ABC from typing import Any, Dict, Optional @@ -14,7 +15,7 @@ def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", From 0b3a1f6f91294114b8569c23d31a054efdb3653e Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 17:30:56 +0200 Subject: [PATCH 08/36] test: added test to http --- ai21/ai21_http_client.py | 7 +++- ai21/http_client.py | 17 ++++++--- tests/conftest.py | 12 ++++++ ...dio_client.py => test_ai21_http_client.py} | 37 +++++++++++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 tests/conftest.py rename tests/{test_ai21_studio_client.py => test_ai21_http_client.py} (67%) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index b58fbb9f..7f6fb66d 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,8 @@ import io from typing import Optional, Dict, Any +import requests + from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException from ai21.http_client import HttpClient @@ -75,8 +77,11 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, + session: Optional[requests.Session] = None, ): - return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) + return self._http_client.execute_http_request( + method=method, url=url, params=params, files=files, session=session + ) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/http_client.py b/ai21/http_client.py index e5d9620a..b8d8a9f0 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -69,12 +69,9 @@ def execute_http_request( params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, auth: Optional[Union[Tuple, Callable]] = None, + session: Optional[requests.Session] = None, ): - session = ( - requests_retry_session(requests.Session(), retries=self._num_retries) - if self._apply_retry_policy - else requests.Session() - ) + session = self._init_session(session) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() @@ -125,3 +122,13 @@ def execute_http_request( handle_non_success_response(response.status_code, response.text) return response.json() + + def _init_session(self, session: Optional[requests.Session]) -> requests.Session: + if session is not None: + return session + + return ( + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy + else requests.Session() + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..02e4d467 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest +import requests + + +@pytest.fixture +def dummy_api_host() -> str: + return "http://test_host" + + +@pytest.fixture +def mock_requests_session(mocker) -> requests.Session: + return mocker.Mock(spec=requests.Session) diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_http_client.py similarity index 67% rename from tests/test_ai21_studio_client.py rename to tests/test_ai21_http_client.py index bbbe2ec6..eb48b451 100644 --- a/tests/test_ai21_studio_client.py +++ b/tests/test_ai21_http_client.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +import requests from ai21.ai21_http_client import AI21HTTPClient from ai21.version import VERSION @@ -8,6 +9,15 @@ _DUMMY_API_KEY = "dummy_key" +class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + class TestAI21StudioClient: @pytest.mark.parametrize( ids=[ @@ -58,3 +68,30 @@ def test__build_headers__when_pass_headers__should_append(self): def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host + + def test__execute_http_request__when_making_request__should_send_appropriate_parameters( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + method = "GET" + url = "test_url" + params = {"foo": "bar"} + response_json = {"test_key": "test_value"} + headers = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", + } + + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + mock_requests_session.request.return_value = MockResponse(response_json, 200) + response = client.execute_http_request(method=method, url=url, params=params, session=mock_requests_session) + assert response == response_json + mock_requests_session.request.assert_called_once_with( + method, + url, + headers=headers, + timeout=300, + params=params, + ) From c7fa29f4bf96b4449d56faa51fe7637ee1eb51b2 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 18:10:43 +0200 Subject: [PATCH 09/36] fix: removed unnecessary auth param --- ai21/http_client.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ai21/http_client.py b/ai21/http_client.py index b8d8a9f0..ecf7bd19 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict, Callable, Tuple, Union +from typing import Optional, Dict import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -68,7 +68,6 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - auth: Optional[Union[Tuple, Callable]] = None, session: Optional[requests.Session] = None, ): session = self._init_session(session) @@ -79,8 +78,8 @@ def execute_http_request( try: if method == "GET": response = session.request( - method, - url, + method=method, + url=url, headers=headers, timeout=timeout, params=params, @@ -95,16 +94,15 @@ def execute_http_request( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload response = session.request( - method, - url, + method=method, + url=url, headers=headers, data=params, files=files, timeout=timeout, - auth=auth, ) else: - response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth) + response = session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error From 37a57cd6863da54c4152fbad2af2d5980e903351 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 18:10:52 +0200 Subject: [PATCH 10/36] test: Added tests --- tests/test_ai21_http_client.py | 71 ++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 17 deletions(-) diff --git a/tests/test_ai21_http_client.py b/tests/test_ai21_http_client.py index eb48b451..942c192a 100644 --- a/tests/test_ai21_http_client.py +++ b/tests/test_ai21_http_client.py @@ -7,6 +7,16 @@ from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" +_EXPECTED_GET_HEADERS = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + +_EXPECTED_POST_FILE_HEADERS = { + "Authorization": "Bearer dummy_key", + "User-Agent": f"ai21 studio SDK {VERSION}", +} class MockResponse: @@ -69,29 +79,56 @@ def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host - def test__execute_http_request__when_making_request__should_send_appropriate_parameters( + @pytest.mark.parametrize( + ids=[ + "when_making_request__should_send_appropriate_parameters", + "when_making_request_with_files__should_send_appropriate_post_request", + ], + argnames=["params", "headers"], + argvalues=[ + ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ( + {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, + _EXPECTED_POST_FILE_HEADERS, + ), + ], + ) + def test__execute_http_request__( self, + params, + headers, dummy_api_host: str, mock_requests_session: requests.Session, ): - method = "GET" - url = "test_url" - params = {"foo": "bar"} response_json = {"test_key": "test_value"} - headers = { - "Authorization": "Bearer dummy_key", - "Content-Type": "application/json", - "User-Agent": f"ai21 studio SDK {VERSION}", - } client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) - response = client.execute_http_request(method=method, url=url, params=params, session=mock_requests_session) + response = client.execute_http_request(session=mock_requests_session, **params) assert response == response_json - mock_requests_session.request.assert_called_once_with( - method, - url, - headers=headers, - timeout=300, - params=params, - ) + + if "files" in params: + # We split it because when calling requests with "files", "params" is turned into "data" + mock_requests_session.request.assert_called_once_with( + timeout=300, + headers=headers, + files=params["files"], + data=params["params"], + url=params["url"], + method=params["method"], + ) + else: + mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + + def test__execute_http_request__when_files_with_put_method__should_raise_value_error( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + mock_requests_session.request.return_value = MockResponse(response_json, 200) + with pytest.raises(ValueError): + params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + client.execute_http_request(session=mock_requests_session, **params) From 0b4c6befd7a5d857537f0b867bd34b3f1aeac1ab Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 11:44:31 +0200 Subject: [PATCH 11/36] test: Added sagemaker --- ai21/services/sagemaker.py | 13 +++-- tests/unittests/services/sagemaker_stub.py | 12 +++++ tests/unittests/services/test_sagemaker.py | 56 ++++++++++++++-------- 3 files changed, 57 insertions(+), 24 deletions(-) create mode 100644 tests/unittests/services/sagemaker_stub.py diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index a4c26d3e..b1387622 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -18,9 +18,9 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21HTTPClient() + client = cls._create_ai21_http_client() - client.execute_http_request( + response = client.execute_http_request( method="POST", url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", params={ @@ -30,7 +30,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = ["arn"] + arn = response["arn"] if not arn: raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) @@ -40,7 +40,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21HTTPClient() + + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -53,6 +54,10 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: return response["versions"] + @classmethod + def _create_ai21_http_client(cls) -> AI21HTTPClient: + return AI21HTTPClient() + def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: diff --git a/tests/unittests/services/sagemaker_stub.py b/tests/unittests/services/sagemaker_stub.py new file mode 100644 index 00000000..16cd98c2 --- /dev/null +++ b/tests/unittests/services/sagemaker_stub.py @@ -0,0 +1,12 @@ +from unittest.mock import Mock + +from ai21 import SageMaker +from ai21.ai21_http_client import AI21HTTPClient + + +class SageMakerStub(SageMaker): + ai21_http_client = Mock(spec=AI21HTTPClient) + + @classmethod + def _create_ai21_http_client(cls): + return cls.ai21_http_client diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index 531fb5ef..a92c23fe 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,28 +1,44 @@ import pytest -from pytest_mock import MockerFixture -from ai21 import SageMaker -from ai21.ai21_http_client import AI21HTTPClient -from unittest.mock import patch +from ai21.errors import ModelPackageDoesntExistException +from tests.unittests.services.sagemaker_stub import SageMakerStub - -@pytest.fixture -def mock_ai21_studio_client(mocker: MockerFixture): - return mocker.patch.object( - AI21HTTPClient, - "execute_http_request", - return_value={ - "arn": "some-model-package-id1", - "versions": ["1.0.0", "1.0.1"], - }, - ) +_DUMMY_ARN = "some-model-package-id1" +_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self, mocker, mock_ai21_studio_client): - with patch("ai21.ai21_studio_client.AI21StudioClient"): - expected_model_package_arn = "some-model-package-id1" + def test__get_model_package_arn__should_return_model_package_arn(self): + expected_response = { + "arn": _DUMMY_ARN, + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_ARN + + def test__get_model_package_arn__when_no_arn__should_raise_error(self): + SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + def test__list_model_package_versions__should_return_model_package_arn(self): + expected_response = { + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_VERSIONS - actual_model_package_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") + def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") - assert actual_model_package_arn == expected_model_package_arn + def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") From c2f15a1cabfdbccdfaeb6588119b1a42318df7b4 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 13:08:04 +0200 Subject: [PATCH 12/36] test: Created a single session per instance --- ai21/ai21_http_client.py | 20 ++++++++++++----- ai21/http_client.py | 22 +++++++++++++------ tests/{ => unittests}/conftest.py | 0 .../{ => unittests}/test_ai21_http_client.py | 18 ++++++++++----- 4 files changed, 42 insertions(+), 18 deletions(-) rename tests/{ => unittests}/conftest.py (100%) rename tests/{ => unittests}/test_ai21_http_client.py (89%) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 7f6fb66d..9bd3cb82 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,7 +1,6 @@ import io from typing import Optional, Dict, Any -import requests from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException @@ -41,7 +40,7 @@ def __init__( self._via = via headers = self._build_headers(passed_headers=headers) - self._http_client = http_client or HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = self._init_http_client(http_client=http_client, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -57,6 +56,18 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, return headers + def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient: + if http_client is None: + return HttpClient( + timeout_sec=self._timeout_sec, + num_retries=self._num_retries, + headers=headers, + ) + + http_client.add_headers(headers) + + return http_client + def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" @@ -77,11 +88,8 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - session: Optional[requests.Session] = None, ): - return self._http_client.execute_http_request( - method=method, url=url, params=params, files=files, session=session - ) + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/http_client.py b/ai21/http_client.py index ecf7bd19..00692e07 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Any import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -56,11 +56,18 @@ def requests_retry_session(session, retries=0): class HttpClient: - def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): + def __init__( + self, + session: Optional[requests.Session] = None, + timeout_sec: int = None, + num_retries: int = None, + headers: Dict = None, + ): self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC self._num_retries = num_retries or DEFAULT_NUM_RETRIES self._headers = headers or {} self._apply_retry_policy = self._num_retries > 0 + self._session = self._init_session(session) def execute_http_request( self, @@ -68,16 +75,14 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - session: Optional[requests.Session] = None, ): - session = self._init_session(session) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: if method == "GET": - response = session.request( + response = self._session.request( method=method, url=url, headers=headers, @@ -93,7 +98,7 @@ def execute_http_request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = session.request( + response = self._session.request( method=method, url=url, headers=headers, @@ -102,7 +107,7 @@ def execute_http_request( timeout=timeout, ) else: - response = session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) + response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error @@ -130,3 +135,6 @@ def _init_session(self, session: Optional[requests.Session]) -> requests.Session if self._apply_retry_policy else requests.Session() ) + + def add_headers(self, headers: Dict[str, Any]) -> None: + self._headers.update(headers) diff --git a/tests/conftest.py b/tests/unittests/conftest.py similarity index 100% rename from tests/conftest.py rename to tests/unittests/conftest.py diff --git a/tests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py similarity index 89% rename from tests/test_ai21_http_client.py rename to tests/unittests/test_ai21_http_client.py index 942c192a..67c4f197 100644 --- a/tests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -4,6 +4,7 @@ import requests from ai21.ai21_http_client import AI21HTTPClient +from ai21.http_client import HttpClient from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" @@ -101,10 +102,14 @@ def test__execute_http_request__( mock_requests_session: requests.Session, ): response_json = {"test_key": "test_value"} - - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) - response = client.execute_http_request(session=mock_requests_session, **params) + + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) + + response = client.execute_http_request(**params) assert response == response_json if "files" in params: @@ -126,9 +131,12 @@ def test__execute_http_request__when_files_with_put_method__should_raise_value_e mock_requests_session: requests.Session, ): response_json = {"test_key": "test_value"} + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) with pytest.raises(ValueError): params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} - client.execute_http_request(session=mock_requests_session, **params) + client.execute_http_request(**params) From e7cf601fa3db489b3f2b15aff2dec497fb4debda Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 13:13:06 +0200 Subject: [PATCH 13/36] ci: removed unnecessary action --- .github/workflows/quality-checks.yml | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 .github/workflows/quality-checks.yml diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml deleted file mode 100644 index ba62006a..00000000 --- a/.github/workflows/quality-checks.yml +++ /dev/null @@ -1,27 +0,0 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json - -name: Quality Checks -concurrency: - group: Quality-Checks-${{ github.head_ref }} - cancel-in-progress: true -on: - pull_request: -jobs: - quality-checks: - runs-on: ubuntu-20.04 - timeout-minutes: 10 - steps: - - name: Checkout - uses: actions/checkout@v3.2.0 - with: - fetch-depth: 0 - - name: Pre-commit - uses: pre-commit/action@v3.0.0 - with: - extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} - # - name: CODEOWNERS validator - # uses: mszostok/codeowners-validator@v0.6.0 - # with: - # checks: files,duppatterns,syntax,owners - # experimental_checks: notowned - # github_access_token: ${{ secrets.GH_PAT_RO }} From 68bc4567c70e3a78569b172120891f53f2c46136 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:00:10 +0200 Subject: [PATCH 14/36] fix: errors --- README.md | 2 +- ai21/__init__.py | 15 ++++++- ai21/clients/bedrock/bedrock_session.py | 4 +- ai21/errors.py | 53 +++---------------------- ai21/http_client.py | 4 +- 5 files changed, 23 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 03cbbae2..ba75a662 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,7 @@ try: except ai21_errors.AI21ServerError as e: print("Server error and could not be reached") print(e.details) -except ai21_errors.TooManyRequests as e: +except ai21_errors.TooManyRequestsError as e: print("A 429 status code was returned. Slow down on the requests") except AI21APIError as e: print("A non 200 status code error. For more error types see ai21.errors") diff --git a/ai21/__init__.py b/ai21/__init__.py index d614a1ca..8e29e72c 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,7 +1,14 @@ from typing import Any from ai21.clients.studio.ai21_client import AI21Client -from ai21.errors import AI21APIError, AI21APITimeoutError +from ai21.errors import ( + AI21APIError, + APITimeoutError, + MissingApiKeyException, + ModelPackageDoesntExistException, + AI21Error, + TooManyRequestsError, +) from ai21.logger import setup_logger from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.responses.chat_response import ChatResponse @@ -60,7 +67,11 @@ def __getattr__(name: str) -> Any: __all__ = [ "AI21Client", "AI21APIError", - "AI21APITimeoutError", + "APITimeoutError", + "AI21Error", + "MissingApiKeyException", + "ModelPackageDoesntExistException", + "TooManyRequestsError", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID", diff --git a/ai21/clients/bedrock/bedrock_session.py b/ai21/clients/bedrock/bedrock_session.py index 7d9f846c..82029da6 100644 --- a/ai21/clients/bedrock/bedrock_session.py +++ b/ai21/clients/bedrock/bedrock_session.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from ai21.logger import logger -from ai21.errors import AccessDenied, NotFound, AI21APITimeoutError +from ai21.errors import AccessDenied, NotFound, APITimeoutError from ai21.http_client import handle_non_success_response _ERROR_MSG_TEMPLATE = ( @@ -52,7 +52,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None: raise NotFound(details=error_message) if status_code == 408: - raise AI21APITimeoutError(details=error_message) + raise APITimeoutError(details=error_message) if status_code == 424: error_message_template = re.compile(_ERROR_MSG_TEMPLATE) diff --git a/ai21/errors.py b/ai21/errors.py index 33cf336b..a72135fb 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -31,7 +31,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(404, details) -class AI21APITimeoutError(AI21APIError): +class APITimeoutError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(408, details) @@ -41,7 +41,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(422, details) -class TooManyRequests(AI21APIError): +class TooManyRequestsError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(429, details) @@ -56,7 +56,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(500, details) -class AI21ClientException(Exception): +class AI21Error(Exception): def __init__(self, message: str): self.message = message super().__init__(message) @@ -65,57 +65,14 @@ def __str__(self) -> str: return f"{type(self).__name__} {self.message}" -class MissingInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is required for the {call_name} call" - super().__init__(message) - - -class UnsupportedInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is unsupported for the {call_name} call" - super().__init__(message) - - -class UnsupportedDestinationException(AI21ClientException): - def __init__(self, destination_name: str, call_name: str): - message = f'Destination of type {destination_name} is unsupported for the "{call_name}" call' - super().__init__(message) - - -class OnlyOneInputException(AI21ClientException): - def __init__(self, field_name1: str, field_name2: str, call_name: str): - message = f"{field_name1} or {field_name2} is required for the {call_name} call, but not both" - super().__init__(message) - - -class WrongInputTypeException(AI21ClientException): - def __init__(self, key: str, expected_type: type, given_type: type): - message = f"Supplied {key} should be {expected_type}, but {given_type} was passed instead" - super().__init__(message) - - -class EmptyMandatoryListException(AI21ClientException): - def __init__(self, key: str): - message = f"Supplied {key} is empty. At least one element should be present in the list" - super().__init__(message) - - -class MissingApiKeyException(AI21ClientException): +class MissingApiKeyException(AI21Error): def __init__(self): message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" super().__init__(message) self.message = message -class NoSpecifiedRegionException(AI21ClientException): - def __init__(self): - message = "No AWS region provided" - super().__init__(message) - self.message = message - - -class ModelPackageDoesntExistException(AI21ClientException): +class ModelPackageDoesntExistException(AI21Error): def __init__(self, model_name: str, region: str, version: Optional[str] = None): message = f"model_name: {model_name} doesn't exist in region: {region}" diff --git a/ai21/http_client.py b/ai21/http_client.py index 00692e07..c55a5fb1 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -10,7 +10,7 @@ BadRequest, Unauthorized, UnprocessableEntity, - TooManyRequests, + TooManyRequestsError, AI21ServerError, ServiceUnavailable, AI21APIError, @@ -32,7 +32,7 @@ def handle_non_success_response(status_code: int, response_text: str): if status_code == 422: raise UnprocessableEntity(details=response_text) if status_code == 429: - raise TooManyRequests(details=response_text) + raise TooManyRequestsError(details=response_text) if status_code == 500: raise AI21ServerError(details=response_text) if status_code == 503: From 6118385e36a8ca86105f1c49a143113eb3691cda Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:03:29 +0200 Subject: [PATCH 15/36] fix: error renames --- ai21/__init__.py | 8 ++++---- ai21/ai21_http_client.py | 4 ++-- ai21/errors.py | 4 ++-- ai21/services/sagemaker.py | 6 +++--- tests/unittests/services/test_sagemaker.py | 8 ++++---- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/ai21/__init__.py b/ai21/__init__.py index 8e29e72c..6c5fb3e9 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -4,8 +4,8 @@ from ai21.errors import ( AI21APIError, APITimeoutError, - MissingApiKeyException, - ModelPackageDoesntExistException, + MissingApiKeyError, + ModelPackageDoesntExistError, AI21Error, TooManyRequestsError, ) @@ -69,8 +69,8 @@ def __getattr__(name: str) -> Any: "AI21APIError", "APITimeoutError", "AI21Error", - "MissingApiKeyException", - "ModelPackageDoesntExistException", + "MissingApiKeyError", + "ModelPackageDoesntExistError", "TooManyRequestsError", "AI21BedrockClient", "AI21SageMakerClient", diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 9bd3cb82..465787f8 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -3,7 +3,7 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig -from ai21.errors import MissingApiKeyException +from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -28,7 +28,7 @@ def __init__( self._api_key = api_key or self._env_config.api_key if self._api_key is None: - raise MissingApiKeyException() + raise MissingApiKeyError() self._api_host = api_host or self._env_config.api_host self._api_version = api_version or self._env_config.api_version diff --git a/ai21/errors.py b/ai21/errors.py index a72135fb..ff4bd921 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -65,14 +65,14 @@ def __str__(self) -> str: return f"{type(self).__name__} {self.message}" -class MissingApiKeyException(AI21Error): +class MissingApiKeyError(AI21Error): def __init__(self): message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" super().__init__(message) self.message = message -class ModelPackageDoesntExistException(AI21Error): +class ModelPackageDoesntExistError(AI21Error): def __init__(self, model_name: str, region: str, version: Optional[str] = None): message = f"model_name: {model_name} doesn't exist in region: {region}" diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index b1387622..f51e1ae2 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -4,7 +4,7 @@ from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError _JUMPSTART_ENDPOINT = "jumpstart" _LIST_VERSIONS_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/list_versions" @@ -33,7 +33,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE arn = response["arn"] if not arn: - raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) + raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) return arn @@ -61,4 +61,4 @@ def _create_ai21_http_client(cls) -> AI21HTTPClient: def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: - raise ModelPackageDoesntExistException(model_name=model_name, region=region) + raise ModelPackageDoesntExistError(model_name=model_name, region=region) diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index a92c23fe..dd36e1c9 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,6 +1,6 @@ import pytest -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError from tests.unittests.services.sagemaker_stub import SageMakerStub _DUMMY_ARN = "some-model-package-id1" @@ -22,7 +22,7 @@ def test__get_model_package_arn__should_return_model_package_arn(self): def test__get_model_package_arn__when_no_arn__should_raise_error(self): SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") def test__list_model_package_versions__should_return_model_package_arn(self): @@ -36,9 +36,9 @@ def test__list_model_package_versions__should_return_model_package_arn(self): assert actual_model_package_arn == _DUMMY_VERSIONS def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") From 4b3c2f33fd871dbb8f36ca26efacc0e03fad8287 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:36:42 +0200 Subject: [PATCH 16/36] fix: rename upload --- README.md | 2 +- ai21/clients/studio/resources/studio_library.py | 2 +- examples/studio/library.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ba75a662..ac7edc0d 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ from ai21 import AI21Client client = AI21Client() -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path="path/to/file", path="path/to/file/in/library", labels=["label1", "label2"], diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 42daedbb..b8f96a3c 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -20,7 +20,7 @@ def __init__(self, client: AI21HTTPClient): class LibraryFiles(StudioResource): _module_name = "library/files" - def upload( + def create( self, file_path: str, *, diff --git a/examples/studio/library.py b/examples/studio/library.py index e1377200..d693d697 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -24,7 +24,7 @@ def validate_file_deleted(): path = os.path.join(file_path, file_name) file_utils.create_file(file_path, file_name, content="test content" * 100) -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path=path, path=file_path, labels=["label1", "label2"], From ae9e8ad2861fdee1c4a9d34385a82ce90fbd82ed Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:38:22 +0200 Subject: [PATCH 17/36] fix: rename type --- ai21/ai21_http_client.py | 6 ++---- ai21/http_client.py | 7 +++---- ai21/resources/studio_resource.py | 5 ++--- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 465787f8..5921a18d 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,4 @@ -import io -from typing import Optional, Dict, Any - +from typing import Optional, Dict, Any, BinaryIO from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyError @@ -87,7 +85,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) diff --git a/ai21/http_client.py b/ai21/http_client.py index c55a5fb1..0eeac1a1 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,11 +1,9 @@ -import io import json -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, BinaryIO import requests from requests.adapters import HTTPAdapter, Retry, RetryError -from ai21.logger import logger from ai21.errors import ( BadRequest, Unauthorized, @@ -15,6 +13,7 @@ ServiceUnavailable, AI21APIError, ) +from ai21.logger import logger DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 @@ -74,7 +73,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): timeout = self._timeout_sec headers = self._headers diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index 7752be91..8ece396e 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,8 +1,7 @@ from __future__ import annotations -import io from abc import ABC -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, BinaryIO from ai21.ai21_http_client import AI21HTTPClient @@ -15,7 +14,7 @@ def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", From c55cbeeebb18f91ee47603cb395275e8de22573c Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:39:53 +0200 Subject: [PATCH 18/36] fix: rename variable --- ai21/clients/studio/resources/studio_answer.py | 2 +- ai21/clients/studio/resources/studio_chat.py | 2 +- ai21/clients/studio/resources/studio_dataset.py | 2 +- ai21/resources/bases/answer_base.py | 2 +- ai21/resources/bases/chat_base.py | 2 +- ai21/resources/bases/dataset_base.py | 2 +- examples/studio/dataset.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index ba79621e..5cd12fac 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -15,7 +15,7 @@ def create( mode: Optional[str] = None, **kwargs, ) -> AnswerResponse: - url = f"{self._client.get_base_url()}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 8fe1bca4..f1dab12b 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{model}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 05a07c52..8626d71b 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -6,7 +6,7 @@ class StudioDataset(StudioResource, Dataset): - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/ai21/resources/bases/answer_base.py b/ai21/resources/bases/answer_base.py index 0fbce8c0..4b11ff5c 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/resources/bases/answer_base.py @@ -5,7 +5,7 @@ class Answer(ABC): - _MODULE_NAME = "answer" + _module_name = "answer" def create( self, diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index e2a67c0d..f85270ee 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _MODULE_NAME = "chat" + _module_name = "chat" @abstractmethod def create( diff --git a/ai21/resources/bases/dataset_base.py b/ai21/resources/bases/dataset_base.py index dd53417c..2be49fc7 100644 --- a/ai21/resources/bases/dataset_base.py +++ b/ai21/resources/bases/dataset_base.py @@ -8,7 +8,7 @@ class Dataset(ABC): _module_name = "dataset" @abstractmethod - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/examples/studio/dataset.py b/examples/studio/dataset.py index b07d6565..87e587cc 100644 --- a/examples/studio/dataset.py +++ b/examples/studio/dataset.py @@ -3,7 +3,7 @@ file_path = "" client = AI21Client() -client.dataset.upload(file_path=file_path, dataset_name="my_new_ds_name") +client.dataset.create(file_path=file_path, dataset_name="my_new_ds_name") result = client.dataset.list() print(result) first_ds_id = result[0].id From e316760e313ffa2e1f98e76b4913eecb29a80d5c Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:40:57 +0200 Subject: [PATCH 19/36] fix: removed experimental --- ai21/clients/studio/resources/studio_completion.py | 5 ----- ai21/resources/bases/completion_base.py | 3 --- 2 files changed, 8 deletions(-) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 48f85fa7..10c1890f 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -18,7 +18,6 @@ def create( top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Dict[str, Any]] = None, presence_penalty: Optional[Dict[str, Any]] = None, @@ -26,9 +25,6 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: - if experimental_mode: - model = f"experimental/{model}" - url = f"{self._client.get_base_url()}/{model}" if custom_model is not None: @@ -45,7 +41,6 @@ def create( top_p=top_p, top_k_return=top_k_return, custom_model=custom_model, - experimental_mode=experimental_mode, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, diff --git a/ai21/resources/bases/completion_base.py b/ai21/resources/bases/completion_base.py index cb286df2..f549306a 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/resources/bases/completion_base.py @@ -20,7 +20,6 @@ def create( top_p=1, top_k_return=0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = (), frequency_penalty: Optional[Dict[str, Any]] = {}, presence_penalty: Optional[Dict[str, Any]] = {}, @@ -44,7 +43,6 @@ def _create_body( top_p: Optional[int], top_k_return: Optional[int], custom_model: Optional[str], - experimental_mode: bool, stop_sequences: Optional[List[str]], frequency_penalty: Optional[Dict[str, Any]], presence_penalty: Optional[Dict[str, Any]], @@ -54,7 +52,6 @@ def _create_body( return { "model": model, "customModel": custom_model, - "experimentalModel": experimental_mode, "prompt": prompt, "maxTokens": max_tokens, "numResults": num_results, From 6dc76143a003e951d0d1c6a657df45b747597e82 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 27 Dec 2023 13:46:42 +0200 Subject: [PATCH 20/36] test: fixed --- ai21/ai21_env_config.py | 4 ---- ai21/ai21_http_client.py | 4 ++-- ai21/clients/studio/resources/studio_improvements.py | 4 ++-- ai21/errors.py | 6 ++++++ tests/unittests/clients/studio/resources/conftest.py | 1 - 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ai21/ai21_env_config.py b/ai21/ai21_env_config.py index 9f3a46be..01ef3501 100644 --- a/ai21/ai21_env_config.py +++ b/ai21/ai21_env_config.py @@ -11,8 +11,6 @@ class _AI21EnvConfig: api_key: Optional[str] = None api_version: str = DEFAULT_API_VERSION api_host: str = STUDIO_HOST - organization: Optional[str] = None - application: Optional[str] = None timeout_sec: Optional[int] = None num_retries: Optional[int] = None aws_region: Optional[str] = None @@ -24,8 +22,6 @@ def from_env(cls) -> _AI21EnvConfig: api_key=os.getenv("AI21_API_KEY"), api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), - organization=os.getenv("AI21_ORGANIZATION"), - application=os.getenv("AI21_APPLICATION"), timeout_sec=os.getenv("AI21_TIMEOUT_SEC"), num_retries=os.getenv("AI21_NUM_RETRIES"), aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"), diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 5921a18d..68007654 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -33,8 +33,8 @@ def __init__( self._headers = headers self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries - self._organization = organization or self._env_config.organization - self._application = application or self._env_config.application + self._organization = organization + self._application = application self._via = via headers = self._build_headers(passed_headers=headers) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 50895e24..86287781 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -1,6 +1,6 @@ from typing import List -from ai21.errors import EmptyMandatoryListException +from ai21.errors import EmptyMandatoryListError from ai21.resources.bases.improvements_base import Improvements from ai21.resources.responses.improvement_response import ImprovementsResponse from ai21.resources.studio_resource import StudioResource @@ -9,7 +9,7 @@ class StudioImprovements(StudioResource, Improvements): def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: if len(types) == 0: - raise EmptyMandatoryListException("types") + raise EmptyMandatoryListError("types") url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, types=types) diff --git a/ai21/errors.py b/ai21/errors.py index ff4bd921..4a0f8c92 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -81,3 +81,9 @@ def __init__(self, model_name: str, region: str, version: Optional[str] = None): super().__init__(message) self.message = message + + +class EmptyMandatoryListError(AI21Error): + def __init__(self, key: str): + message = f"Supplied {key} is empty. At least one element should be present in the list" + super().__init__(message) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 1a921fef..6d94f2a7 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -98,7 +98,6 @@ def get_studio_completion(): "numResults": 1, "topP": 1, "customModel": None, - "experimentalModel": False, "topKReturn": 0, "stopSequences": [], "frequencyPenalty": None, From 69f50e7dead476b164cec0b6e84dfa9fd7aa2362 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 15:58:42 +0200 Subject: [PATCH 21/36] test: Added some unittests to resources --- ai21/services/sagemaker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index b1387622..3031b5dc 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -20,7 +20,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE client = cls._create_ai21_http_client() - response = client.execute_http_request( + client.execute_http_request( method="POST", url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", params={ @@ -30,7 +30,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = response["arn"] + arn = ["arn"] if not arn: raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) From b6d96ffdcb93788e33a87941464c7d099fcab25a Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 16:52:10 +0200 Subject: [PATCH 22/36] test: Added ai21 studio client tsts --- tests/test_ai21_studio_client.py | 60 ++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/test_ai21_studio_client.py diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_studio_client.py new file mode 100644 index 00000000..aaecfc99 --- /dev/null +++ b/tests/test_ai21_studio_client.py @@ -0,0 +1,60 @@ +from typing import Optional + +import pytest + +from ai21.ai21_studio_client import AI21StudioClient +from ai21.version import VERSION + +_DUMMY_API_KEY = "dummy_key" + + +class TestAI21StudioClient: + @pytest.mark.parametrize( + ids=[ + "when_pass_only_via__should_include_via_in_user_agent", + "when_pass_only_application__should_include_application_in_user_agent", + "when_pass_organization__should_include_organization_in_user_agent", + "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", + ], + argnames=["via", "application", "organization", "expected_user_agent"], + argvalues=[ + ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), + (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), + (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), + ( + "langchain", + "studio", + "ai21", + f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", + ), + ], + ) + def test__build_headers__user_agent( + self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str + ): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + assert client._http_client._headers["User-Agent"] == expected_user_agent + + def test__build_headers__authorization(self): + client = AI21StudioClient(api_key=_DUMMY_API_KEY) + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + def test__build_headers__when_pass_headers__should_append(self): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + assert client._http_client._headers["foo"] == "bar" + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + @pytest.mark.parametrize( + ids=[ + "when_api_host_is_not_set__should_return_default", + "when_api_host_is_set__should_return_set_value", + ], + argnames=["api_host", "expected_api_host"], + argvalues=[ + (None, "https://api.ai21.com/studio/v1"), + ("http://test_host", "http://test_host/studio/v1"), + ], + ) + def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + assert client.get_base_url() == expected_api_host From 9b52690b2b3104ac364bf3058b6c66fccadf90cd Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 17:07:46 +0200 Subject: [PATCH 23/36] fix: rename files --- ai21/services/sagemaker.py | 9 +--- tests/test_ai21_studio_client.py | 10 ++-- tests/unittests/services/test_sagemaker.py | 56 ++++++++-------------- 3 files changed, 27 insertions(+), 48 deletions(-) diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index 3031b5dc..a4c26d3e 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -18,7 +18,7 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = cls._create_ai21_http_client() + client = AI21HTTPClient() client.execute_http_request( method="POST", @@ -40,8 +40,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - - client = cls._create_ai21_http_client() + client = AI21HTTPClient() response = client.execute_http_request( method="POST", @@ -54,10 +53,6 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: return response["versions"] - @classmethod - def _create_ai21_http_client(cls) -> AI21HTTPClient: - return AI21HTTPClient() - def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_studio_client.py index aaecfc99..bbbe2ec6 100644 --- a/tests/test_ai21_studio_client.py +++ b/tests/test_ai21_studio_client.py @@ -2,7 +2,7 @@ import pytest -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" @@ -32,15 +32,15 @@ class TestAI21StudioClient: def test__build_headers__user_agent( self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str ): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) assert client._http_client._headers["User-Agent"] == expected_user_agent def test__build_headers__authorization(self): - client = AI21StudioClient(api_key=_DUMMY_API_KEY) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY) assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" def test__build_headers__when_pass_headers__should_append(self): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) assert client._http_client._headers["foo"] == "bar" assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" @@ -56,5 +56,5 @@ def test__build_headers__when_pass_headers__should_append(self): ], ) def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index a92c23fe..531fb5ef 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,44 +1,28 @@ import pytest +from pytest_mock import MockerFixture -from ai21.errors import ModelPackageDoesntExistException -from tests.unittests.services.sagemaker_stub import SageMakerStub +from ai21 import SageMaker +from ai21.ai21_http_client import AI21HTTPClient +from unittest.mock import patch -_DUMMY_ARN = "some-model-package-id1" -_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture): + return mocker.patch.object( + AI21HTTPClient, + "execute_http_request", + return_value={ + "arn": "some-model-package-id1", + "versions": ["1.0.0", "1.0.1"], + }, + ) -class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self): - expected_response = { - "arn": _DUMMY_ARN, - "versions": _DUMMY_VERSIONS, - } - SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response - - actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") - - assert actual_model_package_arn == _DUMMY_ARN - - def test__get_model_package_arn__when_no_arn__should_raise_error(self): - SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} - with pytest.raises(ModelPackageDoesntExistException): - SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") - - def test__list_model_package_versions__should_return_model_package_arn(self): - expected_response = { - "versions": _DUMMY_VERSIONS, - } - SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response - - actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") - - assert actual_model_package_arn == _DUMMY_VERSIONS +class TestSageMakerService: + def test__get_model_package_arn__should_return_model_package_arn(self, mocker, mock_ai21_studio_client): + with patch("ai21.ai21_studio_client.AI21StudioClient"): + expected_model_package_arn = "some-model-package-id1" - def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): - SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") + actual_model_package_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") - def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): - SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") + assert actual_model_package_arn == expected_model_package_arn From c8cba0a75cddb1a981d4bd44b49d861c72661a72 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 11:55:52 +0200 Subject: [PATCH 24/36] fix: Added types --- ai21/ai21_http_client.py | 15 +------------- ai21/http_client.py | 45 ++++++++++++++-------------------------- 2 files changed, 17 insertions(+), 43 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 9bd3cb82..b58fbb9f 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,7 +1,6 @@ import io from typing import Optional, Dict, Any - from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException from ai21.http_client import HttpClient @@ -40,7 +39,7 @@ def __init__( self._via = via headers = self._build_headers(passed_headers=headers) - self._http_client = self._init_http_client(http_client=http_client, headers=headers) + self._http_client = http_client or HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -56,18 +55,6 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, return headers - def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient: - if http_client is None: - return HttpClient( - timeout_sec=self._timeout_sec, - num_retries=self._num_retries, - headers=headers, - ) - - http_client.add_headers(headers) - - return http_client - def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" diff --git a/ai21/http_client.py b/ai21/http_client.py index 00692e07..e5d9620a 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict, Any +from typing import Optional, Dict, Callable, Tuple, Union import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -56,18 +56,11 @@ def requests_retry_session(session, retries=0): class HttpClient: - def __init__( - self, - session: Optional[requests.Session] = None, - timeout_sec: int = None, - num_retries: int = None, - headers: Dict = None, - ): + def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC self._num_retries = num_retries or DEFAULT_NUM_RETRIES self._headers = headers or {} self._apply_retry_policy = self._num_retries > 0 - self._session = self._init_session(session) def execute_http_request( self, @@ -75,16 +68,22 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, + auth: Optional[Union[Tuple, Callable]] = None, ): + session = ( + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy + else requests.Session() + ) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: if method == "GET": - response = self._session.request( - method=method, - url=url, + response = session.request( + method, + url, headers=headers, timeout=timeout, params=params, @@ -98,16 +97,17 @@ def execute_http_request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = self._session.request( - method=method, - url=url, + response = session.request( + method, + url, headers=headers, data=params, files=files, timeout=timeout, + auth=auth, ) else: - response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) + response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error @@ -125,16 +125,3 @@ def execute_http_request( handle_non_success_response(response.status_code, response.text) return response.json() - - def _init_session(self, session: Optional[requests.Session]) -> requests.Session: - if session is not None: - return session - - return ( - requests_retry_session(requests.Session(), retries=self._num_retries) - if self._apply_retry_policy - else requests.Session() - ) - - def add_headers(self, headers: Dict[str, Any]) -> None: - self._headers.update(headers) From 1d4cf235946424b29c101d05098b5ea3426776f0 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 17:30:56 +0200 Subject: [PATCH 25/36] test: added test to http --- ai21/ai21_http_client.py | 7 +++- ai21/http_client.py | 17 ++++++--- tests/conftest.py | 12 ++++++ ...dio_client.py => test_ai21_http_client.py} | 37 +++++++++++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 tests/conftest.py rename tests/{test_ai21_studio_client.py => test_ai21_http_client.py} (67%) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index b58fbb9f..7f6fb66d 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,8 @@ import io from typing import Optional, Dict, Any +import requests + from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException from ai21.http_client import HttpClient @@ -75,8 +77,11 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, + session: Optional[requests.Session] = None, ): - return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) + return self._http_client.execute_http_request( + method=method, url=url, params=params, files=files, session=session + ) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/http_client.py b/ai21/http_client.py index e5d9620a..b8d8a9f0 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -69,12 +69,9 @@ def execute_http_request( params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, auth: Optional[Union[Tuple, Callable]] = None, + session: Optional[requests.Session] = None, ): - session = ( - requests_retry_session(requests.Session(), retries=self._num_retries) - if self._apply_retry_policy - else requests.Session() - ) + session = self._init_session(session) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() @@ -125,3 +122,13 @@ def execute_http_request( handle_non_success_response(response.status_code, response.text) return response.json() + + def _init_session(self, session: Optional[requests.Session]) -> requests.Session: + if session is not None: + return session + + return ( + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy + else requests.Session() + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..02e4d467 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest +import requests + + +@pytest.fixture +def dummy_api_host() -> str: + return "http://test_host" + + +@pytest.fixture +def mock_requests_session(mocker) -> requests.Session: + return mocker.Mock(spec=requests.Session) diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_http_client.py similarity index 67% rename from tests/test_ai21_studio_client.py rename to tests/test_ai21_http_client.py index bbbe2ec6..eb48b451 100644 --- a/tests/test_ai21_studio_client.py +++ b/tests/test_ai21_http_client.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +import requests from ai21.ai21_http_client import AI21HTTPClient from ai21.version import VERSION @@ -8,6 +9,15 @@ _DUMMY_API_KEY = "dummy_key" +class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + class TestAI21StudioClient: @pytest.mark.parametrize( ids=[ @@ -58,3 +68,30 @@ def test__build_headers__when_pass_headers__should_append(self): def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host + + def test__execute_http_request__when_making_request__should_send_appropriate_parameters( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + method = "GET" + url = "test_url" + params = {"foo": "bar"} + response_json = {"test_key": "test_value"} + headers = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", + } + + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + mock_requests_session.request.return_value = MockResponse(response_json, 200) + response = client.execute_http_request(method=method, url=url, params=params, session=mock_requests_session) + assert response == response_json + mock_requests_session.request.assert_called_once_with( + method, + url, + headers=headers, + timeout=300, + params=params, + ) From 750b57de7f18a140a4d9392d4d7a8045119130bf Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 18:10:43 +0200 Subject: [PATCH 26/36] fix: removed unnecessary auth param --- ai21/http_client.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ai21/http_client.py b/ai21/http_client.py index b8d8a9f0..ecf7bd19 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict, Callable, Tuple, Union +from typing import Optional, Dict import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -68,7 +68,6 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - auth: Optional[Union[Tuple, Callable]] = None, session: Optional[requests.Session] = None, ): session = self._init_session(session) @@ -79,8 +78,8 @@ def execute_http_request( try: if method == "GET": response = session.request( - method, - url, + method=method, + url=url, headers=headers, timeout=timeout, params=params, @@ -95,16 +94,15 @@ def execute_http_request( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload response = session.request( - method, - url, + method=method, + url=url, headers=headers, data=params, files=files, timeout=timeout, - auth=auth, ) else: - response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth) + response = session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error From c469376addd02da1906284272f2a5945742a2aed Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 18:10:52 +0200 Subject: [PATCH 27/36] test: Added tests --- tests/test_ai21_http_client.py | 71 ++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 17 deletions(-) diff --git a/tests/test_ai21_http_client.py b/tests/test_ai21_http_client.py index eb48b451..942c192a 100644 --- a/tests/test_ai21_http_client.py +++ b/tests/test_ai21_http_client.py @@ -7,6 +7,16 @@ from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" +_EXPECTED_GET_HEADERS = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + +_EXPECTED_POST_FILE_HEADERS = { + "Authorization": "Bearer dummy_key", + "User-Agent": f"ai21 studio SDK {VERSION}", +} class MockResponse: @@ -69,29 +79,56 @@ def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host - def test__execute_http_request__when_making_request__should_send_appropriate_parameters( + @pytest.mark.parametrize( + ids=[ + "when_making_request__should_send_appropriate_parameters", + "when_making_request_with_files__should_send_appropriate_post_request", + ], + argnames=["params", "headers"], + argvalues=[ + ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ( + {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, + _EXPECTED_POST_FILE_HEADERS, + ), + ], + ) + def test__execute_http_request__( self, + params, + headers, dummy_api_host: str, mock_requests_session: requests.Session, ): - method = "GET" - url = "test_url" - params = {"foo": "bar"} response_json = {"test_key": "test_value"} - headers = { - "Authorization": "Bearer dummy_key", - "Content-Type": "application/json", - "User-Agent": f"ai21 studio SDK {VERSION}", - } client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) - response = client.execute_http_request(method=method, url=url, params=params, session=mock_requests_session) + response = client.execute_http_request(session=mock_requests_session, **params) assert response == response_json - mock_requests_session.request.assert_called_once_with( - method, - url, - headers=headers, - timeout=300, - params=params, - ) + + if "files" in params: + # We split it because when calling requests with "files", "params" is turned into "data" + mock_requests_session.request.assert_called_once_with( + timeout=300, + headers=headers, + files=params["files"], + data=params["params"], + url=params["url"], + method=params["method"], + ) + else: + mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + + def test__execute_http_request__when_files_with_put_method__should_raise_value_error( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + mock_requests_session.request.return_value = MockResponse(response_json, 200) + with pytest.raises(ValueError): + params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + client.execute_http_request(session=mock_requests_session, **params) From ad6f5c1e2131de4e6bb2e535f12013245688ae79 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 11:44:31 +0200 Subject: [PATCH 28/36] test: Added sagemaker --- ai21/services/sagemaker.py | 13 +++-- tests/unittests/services/test_sagemaker.py | 56 ++++++++++++++-------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index a4c26d3e..b1387622 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -18,9 +18,9 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21HTTPClient() + client = cls._create_ai21_http_client() - client.execute_http_request( + response = client.execute_http_request( method="POST", url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", params={ @@ -30,7 +30,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = ["arn"] + arn = response["arn"] if not arn: raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) @@ -40,7 +40,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21HTTPClient() + + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -53,6 +54,10 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: return response["versions"] + @classmethod + def _create_ai21_http_client(cls) -> AI21HTTPClient: + return AI21HTTPClient() + def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index 531fb5ef..a92c23fe 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,28 +1,44 @@ import pytest -from pytest_mock import MockerFixture -from ai21 import SageMaker -from ai21.ai21_http_client import AI21HTTPClient -from unittest.mock import patch +from ai21.errors import ModelPackageDoesntExistException +from tests.unittests.services.sagemaker_stub import SageMakerStub - -@pytest.fixture -def mock_ai21_studio_client(mocker: MockerFixture): - return mocker.patch.object( - AI21HTTPClient, - "execute_http_request", - return_value={ - "arn": "some-model-package-id1", - "versions": ["1.0.0", "1.0.1"], - }, - ) +_DUMMY_ARN = "some-model-package-id1" +_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self, mocker, mock_ai21_studio_client): - with patch("ai21.ai21_studio_client.AI21StudioClient"): - expected_model_package_arn = "some-model-package-id1" + def test__get_model_package_arn__should_return_model_package_arn(self): + expected_response = { + "arn": _DUMMY_ARN, + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_ARN + + def test__get_model_package_arn__when_no_arn__should_raise_error(self): + SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + def test__list_model_package_versions__should_return_model_package_arn(self): + expected_response = { + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_VERSIONS - actual_model_package_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") + def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") - assert actual_model_package_arn == expected_model_package_arn + def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") From 826dd57556f00aec2d9d14219f0106c73e488f35 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 13:08:04 +0200 Subject: [PATCH 29/36] test: Created a single session per instance --- ai21/ai21_http_client.py | 20 +++-- ai21/http_client.py | 22 ++++-- tests/conftest.py | 12 --- tests/test_ai21_http_client.py | 134 --------------------------------- 4 files changed, 29 insertions(+), 159 deletions(-) delete mode 100644 tests/conftest.py delete mode 100644 tests/test_ai21_http_client.py diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 7f6fb66d..9bd3cb82 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,7 +1,6 @@ import io from typing import Optional, Dict, Any -import requests from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException @@ -41,7 +40,7 @@ def __init__( self._via = via headers = self._build_headers(passed_headers=headers) - self._http_client = http_client or HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = self._init_http_client(http_client=http_client, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -57,6 +56,18 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, return headers + def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient: + if http_client is None: + return HttpClient( + timeout_sec=self._timeout_sec, + num_retries=self._num_retries, + headers=headers, + ) + + http_client.add_headers(headers) + + return http_client + def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" @@ -77,11 +88,8 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - session: Optional[requests.Session] = None, ): - return self._http_client.execute_http_request( - method=method, url=url, params=params, files=files, session=session - ) + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/http_client.py b/ai21/http_client.py index ecf7bd19..00692e07 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Any import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -56,11 +56,18 @@ def requests_retry_session(session, retries=0): class HttpClient: - def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): + def __init__( + self, + session: Optional[requests.Session] = None, + timeout_sec: int = None, + num_retries: int = None, + headers: Dict = None, + ): self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC self._num_retries = num_retries or DEFAULT_NUM_RETRIES self._headers = headers or {} self._apply_retry_policy = self._num_retries > 0 + self._session = self._init_session(session) def execute_http_request( self, @@ -68,16 +75,14 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - session: Optional[requests.Session] = None, ): - session = self._init_session(session) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: if method == "GET": - response = session.request( + response = self._session.request( method=method, url=url, headers=headers, @@ -93,7 +98,7 @@ def execute_http_request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = session.request( + response = self._session.request( method=method, url=url, headers=headers, @@ -102,7 +107,7 @@ def execute_http_request( timeout=timeout, ) else: - response = session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) + response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error @@ -130,3 +135,6 @@ def _init_session(self, session: Optional[requests.Session]) -> requests.Session if self._apply_retry_policy else requests.Session() ) + + def add_headers(self, headers: Dict[str, Any]) -> None: + self._headers.update(headers) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 02e4d467..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest -import requests - - -@pytest.fixture -def dummy_api_host() -> str: - return "http://test_host" - - -@pytest.fixture -def mock_requests_session(mocker) -> requests.Session: - return mocker.Mock(spec=requests.Session) diff --git a/tests/test_ai21_http_client.py b/tests/test_ai21_http_client.py deleted file mode 100644 index 942c192a..00000000 --- a/tests/test_ai21_http_client.py +++ /dev/null @@ -1,134 +0,0 @@ -from typing import Optional - -import pytest -import requests - -from ai21.ai21_http_client import AI21HTTPClient -from ai21.version import VERSION - -_DUMMY_API_KEY = "dummy_key" -_EXPECTED_GET_HEADERS = { - "Authorization": "Bearer dummy_key", - "Content-Type": "application/json", - "User-Agent": f"ai21 studio SDK {VERSION}", -} - -_EXPECTED_POST_FILE_HEADERS = { - "Authorization": "Bearer dummy_key", - "User-Agent": f"ai21 studio SDK {VERSION}", -} - - -class MockResponse: - def __init__(self, json_data, status_code): - self.json_data = json_data - self.status_code = status_code - - def json(self): - return self.json_data - - -class TestAI21StudioClient: - @pytest.mark.parametrize( - ids=[ - "when_pass_only_via__should_include_via_in_user_agent", - "when_pass_only_application__should_include_application_in_user_agent", - "when_pass_organization__should_include_organization_in_user_agent", - "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", - ], - argnames=["via", "application", "organization", "expected_user_agent"], - argvalues=[ - ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), - (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), - (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), - ( - "langchain", - "studio", - "ai21", - f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", - ), - ], - ) - def test__build_headers__user_agent( - self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str - ): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) - assert client._http_client._headers["User-Agent"] == expected_user_agent - - def test__build_headers__authorization(self): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY) - assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" - - def test__build_headers__when_pass_headers__should_append(self): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) - assert client._http_client._headers["foo"] == "bar" - assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" - - @pytest.mark.parametrize( - ids=[ - "when_api_host_is_not_set__should_return_default", - "when_api_host_is_set__should_return_set_value", - ], - argnames=["api_host", "expected_api_host"], - argvalues=[ - (None, "https://api.ai21.com/studio/v1"), - ("http://test_host", "http://test_host/studio/v1"), - ], - ) - def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") - assert client.get_base_url() == expected_api_host - - @pytest.mark.parametrize( - ids=[ - "when_making_request__should_send_appropriate_parameters", - "when_making_request_with_files__should_send_appropriate_post_request", - ], - argnames=["params", "headers"], - argvalues=[ - ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), - ( - {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, - _EXPECTED_POST_FILE_HEADERS, - ), - ], - ) - def test__execute_http_request__( - self, - params, - headers, - dummy_api_host: str, - mock_requests_session: requests.Session, - ): - response_json = {"test_key": "test_value"} - - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") - mock_requests_session.request.return_value = MockResponse(response_json, 200) - response = client.execute_http_request(session=mock_requests_session, **params) - assert response == response_json - - if "files" in params: - # We split it because when calling requests with "files", "params" is turned into "data" - mock_requests_session.request.assert_called_once_with( - timeout=300, - headers=headers, - files=params["files"], - data=params["params"], - url=params["url"], - method=params["method"], - ) - else: - mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) - - def test__execute_http_request__when_files_with_put_method__should_raise_value_error( - self, - dummy_api_host: str, - mock_requests_session: requests.Session, - ): - response_json = {"test_key": "test_value"} - - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") - mock_requests_session.request.return_value = MockResponse(response_json, 200) - with pytest.raises(ValueError): - params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} - client.execute_http_request(session=mock_requests_session, **params) From d1cbea18d204ec6ea636a51e0dbc51a0609c0b77 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:00:10 +0200 Subject: [PATCH 30/36] fix: errors --- README.md | 2 +- ai21/__init__.py | 15 ++++++- ai21/clients/bedrock/bedrock_session.py | 4 +- ai21/errors.py | 53 +++---------------------- ai21/http_client.py | 4 +- 5 files changed, 23 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 03cbbae2..ba75a662 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,7 @@ try: except ai21_errors.AI21ServerError as e: print("Server error and could not be reached") print(e.details) -except ai21_errors.TooManyRequests as e: +except ai21_errors.TooManyRequestsError as e: print("A 429 status code was returned. Slow down on the requests") except AI21APIError as e: print("A non 200 status code error. For more error types see ai21.errors") diff --git a/ai21/__init__.py b/ai21/__init__.py index d614a1ca..8e29e72c 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,7 +1,14 @@ from typing import Any from ai21.clients.studio.ai21_client import AI21Client -from ai21.errors import AI21APIError, AI21APITimeoutError +from ai21.errors import ( + AI21APIError, + APITimeoutError, + MissingApiKeyException, + ModelPackageDoesntExistException, + AI21Error, + TooManyRequestsError, +) from ai21.logger import setup_logger from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.responses.chat_response import ChatResponse @@ -60,7 +67,11 @@ def __getattr__(name: str) -> Any: __all__ = [ "AI21Client", "AI21APIError", - "AI21APITimeoutError", + "APITimeoutError", + "AI21Error", + "MissingApiKeyException", + "ModelPackageDoesntExistException", + "TooManyRequestsError", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID", diff --git a/ai21/clients/bedrock/bedrock_session.py b/ai21/clients/bedrock/bedrock_session.py index 7d9f846c..82029da6 100644 --- a/ai21/clients/bedrock/bedrock_session.py +++ b/ai21/clients/bedrock/bedrock_session.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from ai21.logger import logger -from ai21.errors import AccessDenied, NotFound, AI21APITimeoutError +from ai21.errors import AccessDenied, NotFound, APITimeoutError from ai21.http_client import handle_non_success_response _ERROR_MSG_TEMPLATE = ( @@ -52,7 +52,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None: raise NotFound(details=error_message) if status_code == 408: - raise AI21APITimeoutError(details=error_message) + raise APITimeoutError(details=error_message) if status_code == 424: error_message_template = re.compile(_ERROR_MSG_TEMPLATE) diff --git a/ai21/errors.py b/ai21/errors.py index 33cf336b..a72135fb 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -31,7 +31,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(404, details) -class AI21APITimeoutError(AI21APIError): +class APITimeoutError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(408, details) @@ -41,7 +41,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(422, details) -class TooManyRequests(AI21APIError): +class TooManyRequestsError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(429, details) @@ -56,7 +56,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(500, details) -class AI21ClientException(Exception): +class AI21Error(Exception): def __init__(self, message: str): self.message = message super().__init__(message) @@ -65,57 +65,14 @@ def __str__(self) -> str: return f"{type(self).__name__} {self.message}" -class MissingInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is required for the {call_name} call" - super().__init__(message) - - -class UnsupportedInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is unsupported for the {call_name} call" - super().__init__(message) - - -class UnsupportedDestinationException(AI21ClientException): - def __init__(self, destination_name: str, call_name: str): - message = f'Destination of type {destination_name} is unsupported for the "{call_name}" call' - super().__init__(message) - - -class OnlyOneInputException(AI21ClientException): - def __init__(self, field_name1: str, field_name2: str, call_name: str): - message = f"{field_name1} or {field_name2} is required for the {call_name} call, but not both" - super().__init__(message) - - -class WrongInputTypeException(AI21ClientException): - def __init__(self, key: str, expected_type: type, given_type: type): - message = f"Supplied {key} should be {expected_type}, but {given_type} was passed instead" - super().__init__(message) - - -class EmptyMandatoryListException(AI21ClientException): - def __init__(self, key: str): - message = f"Supplied {key} is empty. At least one element should be present in the list" - super().__init__(message) - - -class MissingApiKeyException(AI21ClientException): +class MissingApiKeyException(AI21Error): def __init__(self): message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" super().__init__(message) self.message = message -class NoSpecifiedRegionException(AI21ClientException): - def __init__(self): - message = "No AWS region provided" - super().__init__(message) - self.message = message - - -class ModelPackageDoesntExistException(AI21ClientException): +class ModelPackageDoesntExistException(AI21Error): def __init__(self, model_name: str, region: str, version: Optional[str] = None): message = f"model_name: {model_name} doesn't exist in region: {region}" diff --git a/ai21/http_client.py b/ai21/http_client.py index 00692e07..c55a5fb1 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -10,7 +10,7 @@ BadRequest, Unauthorized, UnprocessableEntity, - TooManyRequests, + TooManyRequestsError, AI21ServerError, ServiceUnavailable, AI21APIError, @@ -32,7 +32,7 @@ def handle_non_success_response(status_code: int, response_text: str): if status_code == 422: raise UnprocessableEntity(details=response_text) if status_code == 429: - raise TooManyRequests(details=response_text) + raise TooManyRequestsError(details=response_text) if status_code == 500: raise AI21ServerError(details=response_text) if status_code == 503: From 6283a991b55d4157ef9ccc5232f5eeda4b084747 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:03:29 +0200 Subject: [PATCH 31/36] fix: error renames --- ai21/__init__.py | 8 ++++---- ai21/ai21_http_client.py | 4 ++-- ai21/errors.py | 4 ++-- ai21/services/sagemaker.py | 6 +++--- tests/unittests/services/test_sagemaker.py | 8 ++++---- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/ai21/__init__.py b/ai21/__init__.py index 8e29e72c..6c5fb3e9 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -4,8 +4,8 @@ from ai21.errors import ( AI21APIError, APITimeoutError, - MissingApiKeyException, - ModelPackageDoesntExistException, + MissingApiKeyError, + ModelPackageDoesntExistError, AI21Error, TooManyRequestsError, ) @@ -69,8 +69,8 @@ def __getattr__(name: str) -> Any: "AI21APIError", "APITimeoutError", "AI21Error", - "MissingApiKeyException", - "ModelPackageDoesntExistException", + "MissingApiKeyError", + "ModelPackageDoesntExistError", "TooManyRequestsError", "AI21BedrockClient", "AI21SageMakerClient", diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 9bd3cb82..465787f8 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -3,7 +3,7 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig -from ai21.errors import MissingApiKeyException +from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -28,7 +28,7 @@ def __init__( self._api_key = api_key or self._env_config.api_key if self._api_key is None: - raise MissingApiKeyException() + raise MissingApiKeyError() self._api_host = api_host or self._env_config.api_host self._api_version = api_version or self._env_config.api_version diff --git a/ai21/errors.py b/ai21/errors.py index a72135fb..ff4bd921 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -65,14 +65,14 @@ def __str__(self) -> str: return f"{type(self).__name__} {self.message}" -class MissingApiKeyException(AI21Error): +class MissingApiKeyError(AI21Error): def __init__(self): message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" super().__init__(message) self.message = message -class ModelPackageDoesntExistException(AI21Error): +class ModelPackageDoesntExistError(AI21Error): def __init__(self, model_name: str, region: str, version: Optional[str] = None): message = f"model_name: {model_name} doesn't exist in region: {region}" diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index b1387622..f51e1ae2 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -4,7 +4,7 @@ from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError _JUMPSTART_ENDPOINT = "jumpstart" _LIST_VERSIONS_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/list_versions" @@ -33,7 +33,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE arn = response["arn"] if not arn: - raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) + raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) return arn @@ -61,4 +61,4 @@ def _create_ai21_http_client(cls) -> AI21HTTPClient: def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: - raise ModelPackageDoesntExistException(model_name=model_name, region=region) + raise ModelPackageDoesntExistError(model_name=model_name, region=region) diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index a92c23fe..dd36e1c9 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,6 +1,6 @@ import pytest -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError from tests.unittests.services.sagemaker_stub import SageMakerStub _DUMMY_ARN = "some-model-package-id1" @@ -22,7 +22,7 @@ def test__get_model_package_arn__should_return_model_package_arn(self): def test__get_model_package_arn__when_no_arn__should_raise_error(self): SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") def test__list_model_package_versions__should_return_model_package_arn(self): @@ -36,9 +36,9 @@ def test__list_model_package_versions__should_return_model_package_arn(self): assert actual_model_package_arn == _DUMMY_VERSIONS def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") From 4bbd6a28b2104cd0d8dbb806eea42b6277c9c85e Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:36:42 +0200 Subject: [PATCH 32/36] fix: rename upload --- README.md | 2 +- ai21/clients/studio/resources/studio_library.py | 2 +- examples/studio/library.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ba75a662..ac7edc0d 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ from ai21 import AI21Client client = AI21Client() -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path="path/to/file", path="path/to/file/in/library", labels=["label1", "label2"], diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 42daedbb..b8f96a3c 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -20,7 +20,7 @@ def __init__(self, client: AI21HTTPClient): class LibraryFiles(StudioResource): _module_name = "library/files" - def upload( + def create( self, file_path: str, *, diff --git a/examples/studio/library.py b/examples/studio/library.py index e1377200..d693d697 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -24,7 +24,7 @@ def validate_file_deleted(): path = os.path.join(file_path, file_name) file_utils.create_file(file_path, file_name, content="test content" * 100) -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path=path, path=file_path, labels=["label1", "label2"], From 4afdfce42960a41ef766807e790634e8bf2b8fc2 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:38:22 +0200 Subject: [PATCH 33/36] fix: rename type --- ai21/ai21_http_client.py | 6 ++---- ai21/http_client.py | 7 +++---- ai21/resources/studio_resource.py | 5 ++--- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 465787f8..5921a18d 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,4 @@ -import io -from typing import Optional, Dict, Any - +from typing import Optional, Dict, Any, BinaryIO from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyError @@ -87,7 +85,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) diff --git a/ai21/http_client.py b/ai21/http_client.py index c55a5fb1..0eeac1a1 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,11 +1,9 @@ -import io import json -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, BinaryIO import requests from requests.adapters import HTTPAdapter, Retry, RetryError -from ai21.logger import logger from ai21.errors import ( BadRequest, Unauthorized, @@ -15,6 +13,7 @@ ServiceUnavailable, AI21APIError, ) +from ai21.logger import logger DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 @@ -74,7 +73,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): timeout = self._timeout_sec headers = self._headers diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index 7752be91..8ece396e 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,8 +1,7 @@ from __future__ import annotations -import io from abc import ABC -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, BinaryIO from ai21.ai21_http_client import AI21HTTPClient @@ -15,7 +14,7 @@ def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", From 70bdd9a19aebf4a158b7e24d6596b0021048de8c Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:39:53 +0200 Subject: [PATCH 34/36] fix: rename variable --- ai21/clients/studio/resources/studio_answer.py | 2 +- ai21/clients/studio/resources/studio_chat.py | 2 +- ai21/clients/studio/resources/studio_dataset.py | 2 +- ai21/resources/bases/answer_base.py | 2 +- ai21/resources/bases/chat_base.py | 2 +- ai21/resources/bases/dataset_base.py | 2 +- examples/studio/dataset.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index ba79621e..5cd12fac 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -15,7 +15,7 @@ def create( mode: Optional[str] = None, **kwargs, ) -> AnswerResponse: - url = f"{self._client.get_base_url()}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 8fe1bca4..f1dab12b 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{model}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 05a07c52..8626d71b 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -6,7 +6,7 @@ class StudioDataset(StudioResource, Dataset): - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/ai21/resources/bases/answer_base.py b/ai21/resources/bases/answer_base.py index 0fbce8c0..4b11ff5c 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/resources/bases/answer_base.py @@ -5,7 +5,7 @@ class Answer(ABC): - _MODULE_NAME = "answer" + _module_name = "answer" def create( self, diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index e2a67c0d..f85270ee 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _MODULE_NAME = "chat" + _module_name = "chat" @abstractmethod def create( diff --git a/ai21/resources/bases/dataset_base.py b/ai21/resources/bases/dataset_base.py index dd53417c..2be49fc7 100644 --- a/ai21/resources/bases/dataset_base.py +++ b/ai21/resources/bases/dataset_base.py @@ -8,7 +8,7 @@ class Dataset(ABC): _module_name = "dataset" @abstractmethod - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/examples/studio/dataset.py b/examples/studio/dataset.py index b07d6565..87e587cc 100644 --- a/examples/studio/dataset.py +++ b/examples/studio/dataset.py @@ -3,7 +3,7 @@ file_path = "" client = AI21Client() -client.dataset.upload(file_path=file_path, dataset_name="my_new_ds_name") +client.dataset.create(file_path=file_path, dataset_name="my_new_ds_name") result = client.dataset.list() print(result) first_ds_id = result[0].id From 3d325d929e91dc327ef547b650c7e2d3a9f6de1a Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Dec 2023 17:40:57 +0200 Subject: [PATCH 35/36] fix: removed experimental --- ai21/clients/studio/resources/studio_completion.py | 5 ----- ai21/resources/bases/completion_base.py | 3 --- 2 files changed, 8 deletions(-) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 48f85fa7..10c1890f 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -18,7 +18,6 @@ def create( top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Dict[str, Any]] = None, presence_penalty: Optional[Dict[str, Any]] = None, @@ -26,9 +25,6 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: - if experimental_mode: - model = f"experimental/{model}" - url = f"{self._client.get_base_url()}/{model}" if custom_model is not None: @@ -45,7 +41,6 @@ def create( top_p=top_p, top_k_return=top_k_return, custom_model=custom_model, - experimental_mode=experimental_mode, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, diff --git a/ai21/resources/bases/completion_base.py b/ai21/resources/bases/completion_base.py index cb286df2..f549306a 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/resources/bases/completion_base.py @@ -20,7 +20,6 @@ def create( top_p=1, top_k_return=0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = (), frequency_penalty: Optional[Dict[str, Any]] = {}, presence_penalty: Optional[Dict[str, Any]] = {}, @@ -44,7 +43,6 @@ def _create_body( top_p: Optional[int], top_k_return: Optional[int], custom_model: Optional[str], - experimental_mode: bool, stop_sequences: Optional[List[str]], frequency_penalty: Optional[Dict[str, Any]], presence_penalty: Optional[Dict[str, Any]], @@ -54,7 +52,6 @@ def _create_body( return { "model": model, "customModel": custom_model, - "experimentalModel": experimental_mode, "prompt": prompt, "maxTokens": max_tokens, "numResults": num_results, From c8bcf10fa9fae443c17e79c898f194efb7ef7f5e Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 27 Dec 2023 13:46:42 +0200 Subject: [PATCH 36/36] test: fixed --- ai21/ai21_env_config.py | 4 ---- ai21/ai21_http_client.py | 4 ++-- ai21/clients/studio/resources/studio_improvements.py | 4 ++-- ai21/errors.py | 6 ++++++ tests/unittests/clients/studio/resources/conftest.py | 1 - 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ai21/ai21_env_config.py b/ai21/ai21_env_config.py index 9f3a46be..01ef3501 100644 --- a/ai21/ai21_env_config.py +++ b/ai21/ai21_env_config.py @@ -11,8 +11,6 @@ class _AI21EnvConfig: api_key: Optional[str] = None api_version: str = DEFAULT_API_VERSION api_host: str = STUDIO_HOST - organization: Optional[str] = None - application: Optional[str] = None timeout_sec: Optional[int] = None num_retries: Optional[int] = None aws_region: Optional[str] = None @@ -24,8 +22,6 @@ def from_env(cls) -> _AI21EnvConfig: api_key=os.getenv("AI21_API_KEY"), api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), - organization=os.getenv("AI21_ORGANIZATION"), - application=os.getenv("AI21_APPLICATION"), timeout_sec=os.getenv("AI21_TIMEOUT_SEC"), num_retries=os.getenv("AI21_NUM_RETRIES"), aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"), diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 5921a18d..68007654 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -33,8 +33,8 @@ def __init__( self._headers = headers self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries - self._organization = organization or self._env_config.organization - self._application = application or self._env_config.application + self._organization = organization + self._application = application self._via = via headers = self._build_headers(passed_headers=headers) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 50895e24..86287781 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -1,6 +1,6 @@ from typing import List -from ai21.errors import EmptyMandatoryListException +from ai21.errors import EmptyMandatoryListError from ai21.resources.bases.improvements_base import Improvements from ai21.resources.responses.improvement_response import ImprovementsResponse from ai21.resources.studio_resource import StudioResource @@ -9,7 +9,7 @@ class StudioImprovements(StudioResource, Improvements): def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: if len(types) == 0: - raise EmptyMandatoryListException("types") + raise EmptyMandatoryListError("types") url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, types=types) diff --git a/ai21/errors.py b/ai21/errors.py index ff4bd921..4a0f8c92 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -81,3 +81,9 @@ def __init__(self, model_name: str, region: str, version: Optional[str] = None): super().__init__(message) self.message = message + + +class EmptyMandatoryListError(AI21Error): + def __init__(self, key: str): + message = f"Supplied {key} is empty. At least one element should be present in the list" + super().__init__(message) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 1a921fef..6d94f2a7 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -98,7 +98,6 @@ def get_studio_completion(): "numResults": 1, "topP": 1, "customModel": None, - "experimentalModel": False, "topKReturn": 0, "stopSequences": [], "frequencyPenalty": None,