From a7be77551de89b22e6db4917cbdbe2f3fc2868d1 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 09:36:16 +0200 Subject: [PATCH 01/13] test: get_tokenizer tests --- ai21/tokenizers/factory.py | 4 +-- tests/unittests/test_dummy.py | 2 -- tests/unittests/tokenizers/__init__.py | 0 .../tokenizers/test_ai21_tokenizer.py | 25 +++++++++++++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) delete mode 100644 tests/unittests/test_dummy.py create mode 100644 tests/unittests/tokenizers/__init__.py create mode 100644 tests/unittests/tokenizers/test_ai21_tokenizer.py diff --git a/ai21/tokenizers/factory.py b/ai21/tokenizers/factory.py index fd229231..cd728f77 100644 --- a/ai21/tokenizers/factory.py +++ b/ai21/tokenizers/factory.py @@ -16,6 +16,6 @@ def get_tokenizer() -> AI21Tokenizer: global _cached_tokenizer if _cached_tokenizer is None: - _cached_tokenizer = Tokenizer.get_tokenizer() + _cached_tokenizer = AI21Tokenizer(Tokenizer.get_tokenizer()) - return AI21Tokenizer(_cached_tokenizer) + return _cached_tokenizer diff --git a/tests/unittests/test_dummy.py b/tests/unittests/test_dummy.py deleted file mode 100644 index 39b433bc..00000000 --- a/tests/unittests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_assert(): - assert True diff --git a/tests/unittests/tokenizers/__init__.py b/tests/unittests/tokenizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/tokenizers/test_ai21_tokenizer.py b/tests/unittests/tokenizers/test_ai21_tokenizer.py new file mode 100644 index 00000000..33f89d9e --- /dev/null +++ b/tests/unittests/tokenizers/test_ai21_tokenizer.py @@ -0,0 +1,25 @@ +from ai21.tokenizers.factory import get_tokenizer + + +class TestAI21Tokenizer: + def test__count_tokens__should_return_number_of_tokens(self): + expected_number_of_tokens = 8 + tokenizer = get_tokenizer() + + actual_number_of_tokens = tokenizer.count_tokens("Text to Tokenize - Hello world!") + + assert actual_number_of_tokens == expected_number_of_tokens + + def test__tokenize__should_return_list_of_tokens(self): + expected_tokens = ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"] + tokenizer = get_tokenizer() + + actual_tokens = tokenizer.tokenize("Text to Tokenize - Hello world!") + + assert actual_tokens == expected_tokens + + def test__tokenizer__should_be_singleton__when_called_twice(self): + tokenizer1 = get_tokenizer() + tokenizer2 = get_tokenizer() + + assert tokenizer1 is tokenizer2 From 22fee638ff21be59c6fc6db45fe6844bc193507c Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 15:57:35 +0200 Subject: [PATCH 02/13] fix: cases --- ai21/clients/studio/resources/studio_chat.py | 2 +- ai21/clients/studio/resources/studio_completion.py | 2 +- ai21/resources/bases/chat_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f1dab12b..8fe1bca4 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._module_name}" + url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 84e179b3..48f85fa7 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -15,7 +15,7 @@ def create( num_results: Optional[int] = 1, min_tokens: Optional[int] = 0, temperature: Optional[float] = 0.7, - top_p: Optional[int] = 1, + top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, experimental_mode: bool = False, diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index f85270ee..e2a67c0d 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _module_name = "chat" + _MODULE_NAME = "chat" @abstractmethod def create( From 68fd37e9640e6bb1eab9db92e74142d437e58700 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 15:58:42 +0200 Subject: [PATCH 03/13] test: Added some unittests to resources --- ai21/services/sagemaker.py | 4 +- poetry.lock | 19 ++- pyproject.toml | 1 + tests/unittests/clients/__init__.py | 0 tests/unittests/clients/studio/__init__.py | 0 .../clients/studio/resources/__init__.py | 0 .../clients/studio/resources/conftest.py | 122 ++++++++++++++++++ .../studio/resources/test_studio_resources.py | 80 ++++++++++++ tests/unittests/services/__init__.py | 0 tests/unittests/services/test_sagemaker.py | 28 ++++ 10 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 tests/unittests/clients/__init__.py create mode 100644 tests/unittests/clients/studio/__init__.py create mode 100644 tests/unittests/clients/studio/resources/__init__.py create mode 100644 tests/unittests/clients/studio/resources/conftest.py create mode 100644 tests/unittests/clients/studio/resources/test_studio_resources.py create mode 100644 tests/unittests/services/__init__.py create mode 100644 tests/unittests/services/test_sagemaker.py diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index fa811541..2d9c6469 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -20,7 +20,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE client = AI21StudioClient() - response = client.execute_http_request( + client.execute_http_request( method="POST", url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", params={ @@ -30,7 +30,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = response["arn"] + arn = ["arn"] if not arn: raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) diff --git a/poetry.lock b/poetry.lock index 334d859e..cd17ee01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -848,6 +848,23 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-mock" +version = "3.12.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1312,4 +1329,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "71ce6369e72538e571ac954b5ebb4e66fa79c1752aa61af336144df577078cc4" +content-hash = "39ea6a4fd93efce593b30be52de954f1d6ab4c2d39745a9541067a5af5f37a21" diff --git a/pyproject.toml b/pyproject.toml index 120b38ae..060d165b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ safety = "*" ruff = "*" python-semantic-release = "^8.5.0" pytest = "^7.4.3" +pytest-mock = "^3.12.0" [tool.poetry.extras] AWS = ["boto3"] diff --git a/tests/unittests/clients/__init__.py b/tests/unittests/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/__init__.py b/tests/unittests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/__init__.py b/tests/unittests/clients/studio/resources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py new file mode 100644 index 00000000..c4b38e0f --- /dev/null +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -0,0 +1,122 @@ +import pytest +from pytest_mock import MockerFixture + +from ai21 import AnswerResponse, ChatResponse, CompletionsResponse +from ai21.ai21_studio_client import AI21StudioClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.clients.studio.resources.studio_chat import StudioChat +from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.resources.responses.chat_response import ChatOutput, FinishReason +from ai21.resources.responses.completion_response import Prompt, Completion, CompletionData, CompletionFinishReason + + +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture) -> AI21StudioClient: + return mocker.MagicMock(spec=AI21StudioClient) + + +def get_studio_answer(): + _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" + _DUMMY_QUESTION = "What is the answer?" + + return ( + StudioAnswer, + {"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION}, + "answer", + { + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + AnswerResponse(id="some-id", answer_in_context=True, answer="42"), + ) + + +def get_studio_chat(): + _DUMMY_MODEL = "dummy-chat-model" + _DUMMY_MESSAGES = [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "name": "Alice", + }, + { + "text": "Hi Alice, I can help you with that. What seems to be the problem?", + "role": "assistant", + "name": "Bob", + }, + ] + _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" + + return ( + StudioChat, + {"model": _DUMMY_MODEL, "messages": _DUMMY_MESSAGES, "system": _DUMMY_SYSTEM}, + f"{_DUMMY_MODEL}/chat", + { + "model": _DUMMY_MODEL, + "system": _DUMMY_SYSTEM, + "messages": _DUMMY_MESSAGES, + "temperature": 0.7, + "maxTokens": 300, + "minTokens": 0, + "numResults": 1, + "topP": 1.0, + "topKReturn": 0, + "stopSequences": None, + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + ChatResponse( + outputs=[ + ChatOutput( + text="Hello, I need help with a signup process.", + role="user", + finish_reason=FinishReason(reason="dummy_reason", length=1, sequence="1"), + ) + ] + ), + ) + + +def get_studio_completion(): + _DUMMY_MODEL = "dummy-completion-model" + _DUMMY_PROMPT = "dummy-prompt" + + return ( + StudioCompletion, + {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT}, + f"{_DUMMY_MODEL}/complete", + { + "model": _DUMMY_MODEL, + "prompt": _DUMMY_PROMPT, + "temperature": 0.7, + "maxTokens": None, + "minTokens": 0, + "epoch": None, + "numResults": 1, + "topP": 1, + "customModel": None, + "experimentalModel": False, + "topKReturn": 0, + "stopSequences": [], + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + CompletionsResponse( + id="some-id", + completions=[ + Completion( + data=CompletionData(text="dummy-completion", tokens=[]), + finish_reason=CompletionFinishReason(reason="dummy_reason", length=1), + ) + ], + prompt=Prompt(text="dummy-prompt"), + ), + ) + + +def get_studio_custom_model(): + pass diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py new file mode 100644 index 00000000..f9051c35 --- /dev/null +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -0,0 +1,80 @@ +from typing import TypeVar, Callable + +import pytest + +from ai21 import AnswerResponse +from ai21.ai21_studio_client import AI21StudioClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.resources.studio_resource import StudioResource +from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion + +_BASE_URL = "https://test.api.ai21.com/studio/v1" +_DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" +_DUMMY_QUESTION = "What is the answer?" + +T = TypeVar("T", bound=StudioResource) + + +class TestStudioResources: + @pytest.mark.parametrize( + ids=[ + "studio_answer", + "studio_chat", + "studio_completion", + ], + argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], + argvalues=[ + (get_studio_answer()), + (get_studio_chat()), + (get_studio_completion()), + ], + ) + def test__create__should_return_answer_response( + self, + studio_resource: Callable[[AI21StudioClient], T], + function_body, + url_suffix: str, + expected_body, + expected_response, + mock_ai21_studio_client: AI21StudioClient, + ): + mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + + resource = studio_resource(mock_ai21_studio_client) + + actual_response = resource.create( + **function_body, + ) + + assert actual_response == expected_response + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=f"{_BASE_URL}/{url_suffix}", + params=expected_body, + files=None, + ) + + def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21StudioClient): + expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") + mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + studio_answer = StudioAnswer(mock_ai21_studio_client) + + studio_answer.create( + context=_DUMMY_CONTEXT, + question=_DUMMY_QUESTION, + some_dummy_kwargs="some_dummy_value", + ) + + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=_BASE_URL + "/answer", + params={ + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + files=None, + ) diff --git a/tests/unittests/services/__init__.py b/tests/unittests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py new file mode 100644 index 00000000..8f2a4fad --- /dev/null +++ b/tests/unittests/services/test_sagemaker.py @@ -0,0 +1,28 @@ +import pytest +from pytest_mock import MockerFixture + +from ai21 import SageMaker +from ai21.ai21_studio_client import AI21StudioClient +from unittest.mock import patch + + +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture): + return mocker.patch.object( + AI21StudioClient, + "execute_http_request", + return_value={ + "arn": "some-model-package-id1", + "versions": ["1.0.0", "1.0.1"], + }, + ) + + +class TestSageMakerService: + def test__get_model_package_arn__should_return_model_package_arn(self, mocker, mock_ai21_studio_client): + with patch("ai21.ai21_studio_client.AI21StudioClient"): + expected_model_package_arn = "some-model-package-id1" + + actual_model_package_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == expected_model_package_arn From bdac3c4b3c6d6c57588b93b4485754603aae8161 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 16:34:25 +0200 Subject: [PATCH 04/13] fix: rename var --- ai21/http_client.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ai21/http_client.py b/ai21/http_client.py index 7f9b0286..9b75126c 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -56,10 +56,10 @@ def requests_retry_session(session, retries=0): class HttpClient: def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): - self.timeout_sec = timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC - self.num_retries = num_retries if num_retries is not None else DEFAULT_NUM_RETRIES - self.headers = headers if headers is not None else {} - self.apply_retry_policy = self.num_retries > 0 + self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC + self._num_retries = num_retries or DEFAULT_NUM_RETRIES + self._headers = headers or {} + self._apply_retry_policy = self._num_retries > 0 def execute_http_request( self, @@ -70,12 +70,12 @@ def execute_http_request( auth=None, ): session = ( - requests_retry_session(requests.Session(), retries=self.num_retries) - if self.apply_retry_policy + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy else requests.Session() ) - timeout = self.timeout_sec - headers = self.headers + timeout = self._timeout_sec + headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: @@ -112,7 +112,7 @@ def execute_http_request( raise connection_error except RetryError as retry_error: logger.error( - f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}" + f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}" ) raise retry_error except Exception as exception: From d1f8be104d8a5f85b2fb4a7a67d26391aa36f934 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 16:52:10 +0200 Subject: [PATCH 05/13] test: Added ai21 studio client tsts --- ai21/ai21_studio_client.py | 8 ++-- ai21/clients/studio/ai21_client.py | 28 +++++++------- tests/test_ai21_studio_client.py | 60 ++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 tests/test_ai21_studio_client.py diff --git a/ai21/ai21_studio_client.py b/ai21/ai21_studio_client.py index 0ceef023..6b6b5543 100644 --- a/ai21/ai21_studio_client.py +++ b/ai21/ai21_studio_client.py @@ -17,6 +17,7 @@ def __init__( timeout_sec: Optional[int] = None, num_retries: Optional[int] = None, organization: Optional[str] = None, + application: Optional[str] = None, via: Optional[str] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): @@ -32,12 +33,11 @@ def __init__( self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries self._organization = organization or self._env_config.organization - self._application = self._env_config.application + self._application = application or self._env_config.application self._via = via headers = self._build_headers(passed_headers=headers) - - self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -68,7 +68,7 @@ def _build_user_agent(self) -> str: return user_agent def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): - return self.http_client.execute_http_request(method=method, url=url, params=params, files=files) + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 025a3bab..7bca5cad 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -35,7 +35,7 @@ def __init__( via: Optional[str] = None, **kwargs, ): - studio_client = AI21StudioClient( + self._studio_client = AI21StudioClient( api_key=api_key, api_host=api_host, headers=headers, @@ -43,19 +43,19 @@ def __init__( num_retries=num_retries, via=via, ) - self.completion = StudioCompletion(studio_client) - self.chat = StudioChat(studio_client) - self.summarize = StudioSummarize(studio_client) - self.embed = StudioEmbed(studio_client) - self.gec = StudioGEC(studio_client) - self.improvements = StudioImprovements(studio_client) - self.paraphrase = StudioParaphrase(studio_client) - self.summarize_by_segment = StudioSummarizeBySegment(studio_client) - self.custom_model = StudioCustomModel(studio_client) - self.dataset = StudioDataset(studio_client) - self.answer = StudioAnswer(studio_client) - self.library = StudioLibrary(studio_client) - self.segmentation = StudioSegmentation(studio_client) + self.completion = StudioCompletion(self._studio_client) + self.chat = StudioChat(self._studio_client) + self.summarize = StudioSummarize(self._studio_client) + self.embed = StudioEmbed(self._studio_client) + self.gec = StudioGEC(self._studio_client) + self.improvements = StudioImprovements(self._studio_client) + self.paraphrase = StudioParaphrase(self._studio_client) + self.summarize_by_segment = StudioSummarizeBySegment(self._studio_client) + self.custom_model = StudioCustomModel(self._studio_client) + self.dataset = StudioDataset(self._studio_client) + self.answer = StudioAnswer(self._studio_client) + self.library = StudioLibrary(self._studio_clienself._studio_clientt) + self.segmentation = StudioSegmentation() def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_studio_client.py new file mode 100644 index 00000000..aaecfc99 --- /dev/null +++ b/tests/test_ai21_studio_client.py @@ -0,0 +1,60 @@ +from typing import Optional + +import pytest + +from ai21.ai21_studio_client import AI21StudioClient +from ai21.version import VERSION + +_DUMMY_API_KEY = "dummy_key" + + +class TestAI21StudioClient: + @pytest.mark.parametrize( + ids=[ + "when_pass_only_via__should_include_via_in_user_agent", + "when_pass_only_application__should_include_application_in_user_agent", + "when_pass_organization__should_include_organization_in_user_agent", + "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", + ], + argnames=["via", "application", "organization", "expected_user_agent"], + argvalues=[ + ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), + (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), + (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), + ( + "langchain", + "studio", + "ai21", + f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", + ), + ], + ) + def test__build_headers__user_agent( + self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str + ): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + assert client._http_client._headers["User-Agent"] == expected_user_agent + + def test__build_headers__authorization(self): + client = AI21StudioClient(api_key=_DUMMY_API_KEY) + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + def test__build_headers__when_pass_headers__should_append(self): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + assert client._http_client._headers["foo"] == "bar" + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + @pytest.mark.parametrize( + ids=[ + "when_api_host_is_not_set__should_return_default", + "when_api_host_is_set__should_return_set_value", + ], + argnames=["api_host", "expected_api_host"], + argvalues=[ + (None, "https://api.ai21.com/studio/v1"), + ("http://test_host", "http://test_host/studio/v1"), + ], + ) + def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): + client = AI21StudioClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + assert client.get_base_url() == expected_api_host From 287b4ebe996a4b7fc7a76455e58665a39eca8574 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 21 Dec 2023 17:07:46 +0200 Subject: [PATCH 06/13] fix: rename files --- ...1_studio_client.py => ai21_http_client.py} | 2 +- ai21/clients/studio/ai21_client.py | 30 +++++++++---------- .../studio/resources/studio_library.py | 4 +-- ai21/resources/studio_resource.py | 4 +-- ai21/services/sagemaker.py | 6 ++-- tests/test_ai21_studio_client.py | 10 +++---- .../clients/studio/resources/conftest.py | 6 ++-- .../studio/resources/test_studio_resources.py | 8 ++--- tests/unittests/services/test_sagemaker.py | 4 +-- 9 files changed, 37 insertions(+), 37 deletions(-) rename ai21/{ai21_studio_client.py => ai21_http_client.py} (99%) diff --git a/ai21/ai21_studio_client.py b/ai21/ai21_http_client.py similarity index 99% rename from ai21/ai21_studio_client.py rename to ai21/ai21_http_client.py index 6b6b5543..df845f39 100644 --- a/ai21/ai21_studio_client.py +++ b/ai21/ai21_http_client.py @@ -6,7 +6,7 @@ from ai21.version import VERSION -class AI21StudioClient: +class AI21HTTPClient: def __init__( self, *, diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 7bca5cad..e308bc0a 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,6 +1,6 @@ from typing import Optional, Any, Dict -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -35,7 +35,7 @@ def __init__( via: Optional[str] = None, **kwargs, ): - self._studio_client = AI21StudioClient( + self._http_client = AI21HTTPClient( api_key=api_key, api_host=api_host, headers=headers, @@ -43,19 +43,19 @@ def __init__( num_retries=num_retries, via=via, ) - self.completion = StudioCompletion(self._studio_client) - self.chat = StudioChat(self._studio_client) - self.summarize = StudioSummarize(self._studio_client) - self.embed = StudioEmbed(self._studio_client) - self.gec = StudioGEC(self._studio_client) - self.improvements = StudioImprovements(self._studio_client) - self.paraphrase = StudioParaphrase(self._studio_client) - self.summarize_by_segment = StudioSummarizeBySegment(self._studio_client) - self.custom_model = StudioCustomModel(self._studio_client) - self.dataset = StudioDataset(self._studio_client) - self.answer = StudioAnswer(self._studio_client) - self.library = StudioLibrary(self._studio_clienself._studio_clientt) - self.segmentation = StudioSegmentation() + self.completion = StudioCompletion(self._http_client) + self.chat = StudioChat(self._http_client) + self.summarize = StudioSummarize(self._http_client) + self.embed = StudioEmbed(self._http_client) + self.gec = StudioGEC(self._http_client) + self.improvements = StudioImprovements(self._http_client) + self.paraphrase = StudioParaphrase(self._http_client) + self.summarize_by_segment = StudioSummarizeBySegment(self._http_client) + self.custom_model = StudioCustomModel(self._http_client) + self.dataset = StudioDataset(self._http_client) + self.answer = StudioAnswer(self._http_client) + self.library = StudioLibrary(self._http_client) + self.segmentation = StudioSegmentation(self._http_client) def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index ae785a85..42daedbb 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,6 @@ from typing import Optional, List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.resources.responses.file_response import FileResponse from ai21.resources.responses.library_answer_response import LibraryAnswerResponse from ai21.resources.responses.library_search_response import LibrarySearchResponse @@ -10,7 +10,7 @@ class StudioLibrary(StudioResource): _module_name = "library/files" - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): super().__init__(client) self.files = LibraryFiles(client) self.search = LibrarySearch(client) diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index a467bd94..ee23341c 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -3,11 +3,11 @@ from abc import ABC from typing import Any, Dict, Optional -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient class StudioResource(ABC): - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): self._client = client def _post( diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index 2d9c6469..a4c26d3e 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -1,6 +1,6 @@ from typing import List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) @@ -18,7 +18,7 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + client = AI21HTTPClient() client.execute_http_request( method="POST", @@ -40,7 +40,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + client = AI21HTTPClient() response = client.execute_http_request( method="POST", diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_studio_client.py index aaecfc99..bbbe2ec6 100644 --- a/tests/test_ai21_studio_client.py +++ b/tests/test_ai21_studio_client.py @@ -2,7 +2,7 @@ import pytest -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" @@ -32,15 +32,15 @@ class TestAI21StudioClient: def test__build_headers__user_agent( self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str ): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) assert client._http_client._headers["User-Agent"] == expected_user_agent def test__build_headers__authorization(self): - client = AI21StudioClient(api_key=_DUMMY_API_KEY) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY) assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" def test__build_headers__when_pass_headers__should_append(self): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) assert client._http_client._headers["foo"] == "bar" assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" @@ -56,5 +56,5 @@ def test__build_headers__when_pass_headers__should_append(self): ], ) def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): - client = AI21StudioClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index c4b38e0f..1a921fef 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -2,7 +2,7 @@ from pytest_mock import MockerFixture from ai21 import AnswerResponse, ChatResponse, CompletionsResponse -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -11,8 +11,8 @@ @pytest.fixture -def mock_ai21_studio_client(mocker: MockerFixture) -> AI21StudioClient: - return mocker.MagicMock(spec=AI21StudioClient) +def mock_ai21_studio_client(mocker: MockerFixture) -> AI21HTTPClient: + return mocker.MagicMock(spec=AI21HTTPClient) def get_studio_answer(): diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index f9051c35..0e4de3af 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -3,7 +3,7 @@ import pytest from ai21 import AnswerResponse -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.resources.studio_resource import StudioResource from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion @@ -31,12 +31,12 @@ class TestStudioResources: ) def test__create__should_return_answer_response( self, - studio_resource: Callable[[AI21StudioClient], T], + studio_resource: Callable[[AI21HTTPClient], T], function_body, url_suffix: str, expected_body, expected_response, - mock_ai21_studio_client: AI21StudioClient, + mock_ai21_studio_client: AI21HTTPClient, ): mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() mock_ai21_studio_client.get_base_url.return_value = _BASE_URL @@ -55,7 +55,7 @@ def test__create__should_return_answer_response( files=None, ) - def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21StudioClient): + def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21HTTPClient): expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() mock_ai21_studio_client.get_base_url.return_value = _BASE_URL diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index 8f2a4fad..531fb5ef 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -2,14 +2,14 @@ from pytest_mock import MockerFixture from ai21 import SageMaker -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from unittest.mock import patch @pytest.fixture def mock_ai21_studio_client(mocker: MockerFixture): return mocker.patch.object( - AI21StudioClient, + AI21HTTPClient, "execute_http_request", return_value={ "arn": "some-model-package-id1", From 8697aeeed1b6ecdcbdc4062e0cbb07f2aa3a7e6b Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 11:55:52 +0200 Subject: [PATCH 07/13] fix: Added types --- ai21/ai21_http_client.py | 12 ++++++++++-- ai21/clients/studio/ai21_client.py | 3 +++ ai21/http_client.py | 7 ++++--- ai21/resources/studio_resource.py | 3 ++- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index df845f39..b58fbb9f 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,3 +1,4 @@ +import io from typing import Optional, Dict, Any from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig @@ -19,6 +20,7 @@ def __init__( organization: Optional[str] = None, application: Optional[str] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): self._env_config = env_config @@ -37,7 +39,7 @@ def __init__( self._via = via headers = self._build_headers(passed_headers=headers) - self._http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = http_client or HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -67,7 +69,13 @@ def _build_user_agent(self) -> str: return user_agent - def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): + def execute_http_request( + self, + method: str, + url: str, + params: Optional[Dict] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, + ): return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index e308bc0a..5f6cee06 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -14,6 +14,7 @@ from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation from ai21.clients.studio.resources.studio_summarize import StudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment +from ai21.http_client import HttpClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer @@ -33,6 +34,7 @@ def __init__( timeout_sec: Optional[float] = None, num_retries: Optional[int] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, **kwargs, ): self._http_client = AI21HTTPClient( @@ -42,6 +44,7 @@ def __init__( timeout_sec=timeout_sec, num_retries=num_retries, via=via, + http_client=http_client, ) self.completion = StudioCompletion(self._http_client) self.chat = StudioChat(self._http_client) diff --git a/ai21/http_client.py b/ai21/http_client.py index 9b75126c..e5d9620a 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,5 +1,6 @@ +import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Callable, Tuple, Union import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -66,8 +67,8 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files=None, - auth=None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, + auth: Optional[Union[Tuple, Callable]] = None, ): session = ( requests_retry_session(requests.Session(), retries=self._num_retries) diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index ee23341c..7752be91 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from abc import ABC from typing import Any, Dict, Optional @@ -14,7 +15,7 @@ def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", From 0b3a1f6f91294114b8569c23d31a054efdb3653e Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 17:30:56 +0200 Subject: [PATCH 08/13] test: added test to http --- ai21/ai21_http_client.py | 7 +++- ai21/http_client.py | 17 ++++++--- tests/conftest.py | 12 ++++++ ...dio_client.py => test_ai21_http_client.py} | 37 +++++++++++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 tests/conftest.py rename tests/{test_ai21_studio_client.py => test_ai21_http_client.py} (67%) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index b58fbb9f..7f6fb66d 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,8 @@ import io from typing import Optional, Dict, Any +import requests + from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException from ai21.http_client import HttpClient @@ -75,8 +77,11 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, + session: Optional[requests.Session] = None, ): - return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) + return self._http_client.execute_http_request( + method=method, url=url, params=params, files=files, session=session + ) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/http_client.py b/ai21/http_client.py index e5d9620a..b8d8a9f0 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -69,12 +69,9 @@ def execute_http_request( params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, auth: Optional[Union[Tuple, Callable]] = None, + session: Optional[requests.Session] = None, ): - session = ( - requests_retry_session(requests.Session(), retries=self._num_retries) - if self._apply_retry_policy - else requests.Session() - ) + session = self._init_session(session) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() @@ -125,3 +122,13 @@ def execute_http_request( handle_non_success_response(response.status_code, response.text) return response.json() + + def _init_session(self, session: Optional[requests.Session]) -> requests.Session: + if session is not None: + return session + + return ( + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy + else requests.Session() + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..02e4d467 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest +import requests + + +@pytest.fixture +def dummy_api_host() -> str: + return "http://test_host" + + +@pytest.fixture +def mock_requests_session(mocker) -> requests.Session: + return mocker.Mock(spec=requests.Session) diff --git a/tests/test_ai21_studio_client.py b/tests/test_ai21_http_client.py similarity index 67% rename from tests/test_ai21_studio_client.py rename to tests/test_ai21_http_client.py index bbbe2ec6..eb48b451 100644 --- a/tests/test_ai21_studio_client.py +++ b/tests/test_ai21_http_client.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +import requests from ai21.ai21_http_client import AI21HTTPClient from ai21.version import VERSION @@ -8,6 +9,15 @@ _DUMMY_API_KEY = "dummy_key" +class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + class TestAI21StudioClient: @pytest.mark.parametrize( ids=[ @@ -58,3 +68,30 @@ def test__build_headers__when_pass_headers__should_append(self): def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host + + def test__execute_http_request__when_making_request__should_send_appropriate_parameters( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + method = "GET" + url = "test_url" + params = {"foo": "bar"} + response_json = {"test_key": "test_value"} + headers = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", + } + + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + mock_requests_session.request.return_value = MockResponse(response_json, 200) + response = client.execute_http_request(method=method, url=url, params=params, session=mock_requests_session) + assert response == response_json + mock_requests_session.request.assert_called_once_with( + method, + url, + headers=headers, + timeout=300, + params=params, + ) From c7fa29f4bf96b4449d56faa51fe7637ee1eb51b2 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 18:10:43 +0200 Subject: [PATCH 09/13] fix: removed unnecessary auth param --- ai21/http_client.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ai21/http_client.py b/ai21/http_client.py index b8d8a9f0..ecf7bd19 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict, Callable, Tuple, Union +from typing import Optional, Dict import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -68,7 +68,6 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - auth: Optional[Union[Tuple, Callable]] = None, session: Optional[requests.Session] = None, ): session = self._init_session(session) @@ -79,8 +78,8 @@ def execute_http_request( try: if method == "GET": response = session.request( - method, - url, + method=method, + url=url, headers=headers, timeout=timeout, params=params, @@ -95,16 +94,15 @@ def execute_http_request( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload response = session.request( - method, - url, + method=method, + url=url, headers=headers, data=params, files=files, timeout=timeout, - auth=auth, ) else: - response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth) + response = session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error From 37a57cd6863da54c4152fbad2af2d5980e903351 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Dec 2023 18:10:52 +0200 Subject: [PATCH 10/13] test: Added tests --- tests/test_ai21_http_client.py | 71 ++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 17 deletions(-) diff --git a/tests/test_ai21_http_client.py b/tests/test_ai21_http_client.py index eb48b451..942c192a 100644 --- a/tests/test_ai21_http_client.py +++ b/tests/test_ai21_http_client.py @@ -7,6 +7,16 @@ from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" +_EXPECTED_GET_HEADERS = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + +_EXPECTED_POST_FILE_HEADERS = { + "Authorization": "Bearer dummy_key", + "User-Agent": f"ai21 studio SDK {VERSION}", +} class MockResponse: @@ -69,29 +79,56 @@ def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") assert client.get_base_url() == expected_api_host - def test__execute_http_request__when_making_request__should_send_appropriate_parameters( + @pytest.mark.parametrize( + ids=[ + "when_making_request__should_send_appropriate_parameters", + "when_making_request_with_files__should_send_appropriate_post_request", + ], + argnames=["params", "headers"], + argvalues=[ + ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ( + {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, + _EXPECTED_POST_FILE_HEADERS, + ), + ], + ) + def test__execute_http_request__( self, + params, + headers, dummy_api_host: str, mock_requests_session: requests.Session, ): - method = "GET" - url = "test_url" - params = {"foo": "bar"} response_json = {"test_key": "test_value"} - headers = { - "Authorization": "Bearer dummy_key", - "Content-Type": "application/json", - "User-Agent": f"ai21 studio SDK {VERSION}", - } client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) - response = client.execute_http_request(method=method, url=url, params=params, session=mock_requests_session) + response = client.execute_http_request(session=mock_requests_session, **params) assert response == response_json - mock_requests_session.request.assert_called_once_with( - method, - url, - headers=headers, - timeout=300, - params=params, - ) + + if "files" in params: + # We split it because when calling requests with "files", "params" is turned into "data" + mock_requests_session.request.assert_called_once_with( + timeout=300, + headers=headers, + files=params["files"], + data=params["params"], + url=params["url"], + method=params["method"], + ) + else: + mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + + def test__execute_http_request__when_files_with_put_method__should_raise_value_error( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + mock_requests_session.request.return_value = MockResponse(response_json, 200) + with pytest.raises(ValueError): + params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + client.execute_http_request(session=mock_requests_session, **params) From 0b4c6befd7a5d857537f0b867bd34b3f1aeac1ab Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 11:44:31 +0200 Subject: [PATCH 11/13] test: Added sagemaker --- ai21/services/sagemaker.py | 13 +++-- tests/unittests/services/sagemaker_stub.py | 12 +++++ tests/unittests/services/test_sagemaker.py | 56 ++++++++++++++-------- 3 files changed, 57 insertions(+), 24 deletions(-) create mode 100644 tests/unittests/services/sagemaker_stub.py diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index a4c26d3e..b1387622 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -18,9 +18,9 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21HTTPClient() + client = cls._create_ai21_http_client() - client.execute_http_request( + response = client.execute_http_request( method="POST", url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", params={ @@ -30,7 +30,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE }, ) - arn = ["arn"] + arn = response["arn"] if not arn: raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) @@ -40,7 +40,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21HTTPClient() + + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -53,6 +54,10 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: return response["versions"] + @classmethod + def _create_ai21_http_client(cls) -> AI21HTTPClient: + return AI21HTTPClient() + def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: diff --git a/tests/unittests/services/sagemaker_stub.py b/tests/unittests/services/sagemaker_stub.py new file mode 100644 index 00000000..16cd98c2 --- /dev/null +++ b/tests/unittests/services/sagemaker_stub.py @@ -0,0 +1,12 @@ +from unittest.mock import Mock + +from ai21 import SageMaker +from ai21.ai21_http_client import AI21HTTPClient + + +class SageMakerStub(SageMaker): + ai21_http_client = Mock(spec=AI21HTTPClient) + + @classmethod + def _create_ai21_http_client(cls): + return cls.ai21_http_client diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index 531fb5ef..a92c23fe 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,28 +1,44 @@ import pytest -from pytest_mock import MockerFixture -from ai21 import SageMaker -from ai21.ai21_http_client import AI21HTTPClient -from unittest.mock import patch +from ai21.errors import ModelPackageDoesntExistException +from tests.unittests.services.sagemaker_stub import SageMakerStub - -@pytest.fixture -def mock_ai21_studio_client(mocker: MockerFixture): - return mocker.patch.object( - AI21HTTPClient, - "execute_http_request", - return_value={ - "arn": "some-model-package-id1", - "versions": ["1.0.0", "1.0.1"], - }, - ) +_DUMMY_ARN = "some-model-package-id1" +_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self, mocker, mock_ai21_studio_client): - with patch("ai21.ai21_studio_client.AI21StudioClient"): - expected_model_package_arn = "some-model-package-id1" + def test__get_model_package_arn__should_return_model_package_arn(self): + expected_response = { + "arn": _DUMMY_ARN, + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_ARN + + def test__get_model_package_arn__when_no_arn__should_raise_error(self): + SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + def test__list_model_package_versions__should_return_model_package_arn(self): + expected_response = { + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_VERSIONS - actual_model_package_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") + def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") - assert actual_model_package_arn == expected_model_package_arn + def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") From c2f15a1cabfdbccdfaeb6588119b1a42318df7b4 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 13:08:04 +0200 Subject: [PATCH 12/13] test: Created a single session per instance --- ai21/ai21_http_client.py | 20 ++++++++++++----- ai21/http_client.py | 22 +++++++++++++------ tests/{ => unittests}/conftest.py | 0 .../{ => unittests}/test_ai21_http_client.py | 18 ++++++++++----- 4 files changed, 42 insertions(+), 18 deletions(-) rename tests/{ => unittests}/conftest.py (100%) rename tests/{ => unittests}/test_ai21_http_client.py (89%) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 7f6fb66d..9bd3cb82 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,7 +1,6 @@ import io from typing import Optional, Dict, Any -import requests from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException @@ -41,7 +40,7 @@ def __init__( self._via = via headers = self._build_headers(passed_headers=headers) - self._http_client = http_client or HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = self._init_http_client(http_client=http_client, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -57,6 +56,18 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, return headers + def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient: + if http_client is None: + return HttpClient( + timeout_sec=self._timeout_sec, + num_retries=self._num_retries, + headers=headers, + ) + + http_client.add_headers(headers) + + return http_client + def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" @@ -77,11 +88,8 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - session: Optional[requests.Session] = None, ): - return self._http_client.execute_http_request( - method=method, url=url, params=params, files=files, session=session - ) + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/http_client.py b/ai21/http_client.py index ecf7bd19..00692e07 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,6 +1,6 @@ import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Any import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -56,11 +56,18 @@ def requests_retry_session(session, retries=0): class HttpClient: - def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): + def __init__( + self, + session: Optional[requests.Session] = None, + timeout_sec: int = None, + num_retries: int = None, + headers: Dict = None, + ): self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC self._num_retries = num_retries or DEFAULT_NUM_RETRIES self._headers = headers or {} self._apply_retry_policy = self._num_retries > 0 + self._session = self._init_session(session) def execute_http_request( self, @@ -68,16 +75,14 @@ def execute_http_request( url: str, params: Optional[Dict] = None, files: Optional[Dict[str, io.TextIOWrapper]] = None, - session: Optional[requests.Session] = None, ): - session = self._init_session(session) timeout = self._timeout_sec headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: if method == "GET": - response = session.request( + response = self._session.request( method=method, url=url, headers=headers, @@ -93,7 +98,7 @@ def execute_http_request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = session.request( + response = self._session.request( method=method, url=url, headers=headers, @@ -102,7 +107,7 @@ def execute_http_request( timeout=timeout, ) else: - response = session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) + response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error @@ -130,3 +135,6 @@ def _init_session(self, session: Optional[requests.Session]) -> requests.Session if self._apply_retry_policy else requests.Session() ) + + def add_headers(self, headers: Dict[str, Any]) -> None: + self._headers.update(headers) diff --git a/tests/conftest.py b/tests/unittests/conftest.py similarity index 100% rename from tests/conftest.py rename to tests/unittests/conftest.py diff --git a/tests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py similarity index 89% rename from tests/test_ai21_http_client.py rename to tests/unittests/test_ai21_http_client.py index 942c192a..67c4f197 100644 --- a/tests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -4,6 +4,7 @@ import requests from ai21.ai21_http_client import AI21HTTPClient +from ai21.http_client import HttpClient from ai21.version import VERSION _DUMMY_API_KEY = "dummy_key" @@ -101,10 +102,14 @@ def test__execute_http_request__( mock_requests_session: requests.Session, ): response_json = {"test_key": "test_value"} - - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) - response = client.execute_http_request(session=mock_requests_session, **params) + + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) + + response = client.execute_http_request(**params) assert response == response_json if "files" in params: @@ -126,9 +131,12 @@ def test__execute_http_request__when_files_with_put_method__should_raise_value_e mock_requests_session: requests.Session, ): response_json = {"test_key": "test_value"} + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") mock_requests_session.request.return_value = MockResponse(response_json, 200) with pytest.raises(ValueError): params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} - client.execute_http_request(session=mock_requests_session, **params) + client.execute_http_request(**params) From e7cf601fa3db489b3f2b15aff2dec497fb4debda Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Dec 2023 13:13:06 +0200 Subject: [PATCH 13/13] ci: removed unnecessary action --- .github/workflows/quality-checks.yml | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 .github/workflows/quality-checks.yml diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml deleted file mode 100644 index ba62006a..00000000 --- a/.github/workflows/quality-checks.yml +++ /dev/null @@ -1,27 +0,0 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json - -name: Quality Checks -concurrency: - group: Quality-Checks-${{ github.head_ref }} - cancel-in-progress: true -on: - pull_request: -jobs: - quality-checks: - runs-on: ubuntu-20.04 - timeout-minutes: 10 - steps: - - name: Checkout - uses: actions/checkout@v3.2.0 - with: - fetch-depth: 0 - - name: Pre-commit - uses: pre-commit/action@v3.0.0 - with: - extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} - # - name: CODEOWNERS validator - # uses: mszostok/codeowners-validator@v0.6.0 - # with: - # checks: files,duppatterns,syntax,owners - # experimental_checks: notowned - # github_access_token: ${{ secrets.GH_PAT_RO }}