diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml deleted file mode 100644 index ba62006a..00000000 --- a/.github/workflows/quality-checks.yml +++ /dev/null @@ -1,27 +0,0 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json - -name: Quality Checks -concurrency: - group: Quality-Checks-${{ github.head_ref }} - cancel-in-progress: true -on: - pull_request: -jobs: - quality-checks: - runs-on: ubuntu-20.04 - timeout-minutes: 10 - steps: - - name: Checkout - uses: actions/checkout@v3.2.0 - with: - fetch-depth: 0 - - name: Pre-commit - uses: pre-commit/action@v3.0.0 - with: - extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} - # - name: CODEOWNERS validator - # uses: mszostok/codeowners-validator@v0.6.0 - # with: - # checks: files,duppatterns,syntax,owners - # experimental_checks: notowned - # github_access_token: ${{ secrets.GH_PAT_RO }} diff --git a/ai21/ai21_studio_client.py b/ai21/ai21_http_client.py similarity index 70% rename from ai21/ai21_studio_client.py rename to ai21/ai21_http_client.py index 0ceef023..9bd3cb82 100644 --- a/ai21/ai21_studio_client.py +++ b/ai21/ai21_http_client.py @@ -1,12 +1,14 @@ +import io from typing import Optional, Dict, Any + from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException from ai21.http_client import HttpClient from ai21.version import VERSION -class AI21StudioClient: +class AI21HTTPClient: def __init__( self, *, @@ -17,7 +19,9 @@ def __init__( timeout_sec: Optional[int] = None, num_retries: Optional[int] = None, organization: Optional[str] = None, + application: Optional[str] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): self._env_config = env_config @@ -32,12 +36,11 @@ def __init__( self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries self._organization = organization or self._env_config.organization - self._application = self._env_config.application + self._application = application or self._env_config.application self._via = via headers = self._build_headers(passed_headers=headers) - - self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = self._init_http_client(http_client=http_client, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -53,6 +56,18 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, return headers + def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient: + if http_client is None: + return HttpClient( + timeout_sec=self._timeout_sec, + num_retries=self._num_retries, + headers=headers, + ) + + http_client.add_headers(headers) + + return http_client + def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" @@ -67,8 +82,14 @@ def _build_user_agent(self) -> str: return user_agent - def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): - return self.http_client.execute_http_request(method=method, url=url, params=params, files=files) + def execute_http_request( + self, + method: str, + url: str, + params: Optional[Dict] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, + ): + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 025a3bab..5f6cee06 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,6 +1,6 @@ from typing import Optional, Any, Dict -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -14,6 +14,7 @@ from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation from ai21.clients.studio.resources.studio_summarize import StudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment +from ai21.http_client import HttpClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer @@ -33,29 +34,31 @@ def __init__( timeout_sec: Optional[float] = None, num_retries: Optional[int] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, **kwargs, ): - studio_client = AI21StudioClient( + self._http_client = AI21HTTPClient( api_key=api_key, api_host=api_host, headers=headers, timeout_sec=timeout_sec, num_retries=num_retries, via=via, + http_client=http_client, ) - self.completion = StudioCompletion(studio_client) - self.chat = StudioChat(studio_client) - self.summarize = StudioSummarize(studio_client) - self.embed = StudioEmbed(studio_client) - self.gec = StudioGEC(studio_client) - self.improvements = StudioImprovements(studio_client) - self.paraphrase = StudioParaphrase(studio_client) - self.summarize_by_segment = StudioSummarizeBySegment(studio_client) - self.custom_model = StudioCustomModel(studio_client) - self.dataset = StudioDataset(studio_client) - self.answer = StudioAnswer(studio_client) - self.library = StudioLibrary(studio_client) - self.segmentation = StudioSegmentation(studio_client) + self.completion = StudioCompletion(self._http_client) + self.chat = StudioChat(self._http_client) + self.summarize = StudioSummarize(self._http_client) + self.embed = StudioEmbed(self._http_client) + self.gec = StudioGEC(self._http_client) + self.improvements = StudioImprovements(self._http_client) + self.paraphrase = StudioParaphrase(self._http_client) + self.summarize_by_segment = StudioSummarizeBySegment(self._http_client) + self.custom_model = StudioCustomModel(self._http_client) + self.dataset = StudioDataset(self._http_client) + self.answer = StudioAnswer(self._http_client) + self.library = StudioLibrary(self._http_client) + self.segmentation = StudioSegmentation(self._http_client) def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f1dab12b..8fe1bca4 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._module_name}" + url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 84e179b3..48f85fa7 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -15,7 +15,7 @@ def create( num_results: Optional[int] = 1, min_tokens: Optional[int] = 0, temperature: Optional[float] = 0.7, - top_p: Optional[int] = 1, + top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, experimental_mode: bool = False, diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index ae785a85..42daedbb 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,6 @@ from typing import Optional, List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.resources.responses.file_response import FileResponse from ai21.resources.responses.library_answer_response import LibraryAnswerResponse from ai21.resources.responses.library_search_response import LibrarySearchResponse @@ -10,7 +10,7 @@ class StudioLibrary(StudioResource): _module_name = "library/files" - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): super().__init__(client) self.files = LibraryFiles(client) self.search = LibrarySearch(client) diff --git a/ai21/http_client.py b/ai21/http_client.py index 7f9b0286..00692e07 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,5 +1,6 @@ +import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Any import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -55,34 +56,35 @@ def requests_retry_session(session, retries=0): class HttpClient: - def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): - self.timeout_sec = timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC - self.num_retries = num_retries if num_retries is not None else DEFAULT_NUM_RETRIES - self.headers = headers if headers is not None else {} - self.apply_retry_policy = self.num_retries > 0 + def __init__( + self, + session: Optional[requests.Session] = None, + timeout_sec: int = None, + num_retries: int = None, + headers: Dict = None, + ): + self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC + self._num_retries = num_retries or DEFAULT_NUM_RETRIES + self._headers = headers or {} + self._apply_retry_policy = self._num_retries > 0 + self._session = self._init_session(session) def execute_http_request( self, method: str, url: str, params: Optional[Dict] = None, - files=None, - auth=None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, ): - session = ( - requests_retry_session(requests.Session(), retries=self.num_retries) - if self.apply_retry_policy - else requests.Session() - ) - timeout = self.timeout_sec - headers = self.headers + timeout = self._timeout_sec + headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: if method == "GET": - response = session.request( - method, - url, + response = self._session.request( + method=method, + url=url, headers=headers, timeout=timeout, params=params, @@ -96,23 +98,22 @@ def execute_http_request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = session.request( - method, - url, + response = self._session.request( + method=method, + url=url, headers=headers, data=params, files=files, timeout=timeout, - auth=auth, ) else: - response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth) + response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error except RetryError as retry_error: logger.error( - f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}" + f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}" ) raise retry_error except Exception as exception: @@ -124,3 +125,16 @@ def execute_http_request( handle_non_success_response(response.status_code, response.text) return response.json() + + def _init_session(self, session: Optional[requests.Session]) -> requests.Session: + if session is not None: + return session + + return ( + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy + else requests.Session() + ) + + def add_headers(self, headers: Dict[str, Any]) -> None: + self._headers.update(headers) diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index f85270ee..e2a67c0d 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _module_name = "chat" + _MODULE_NAME = "chat" @abstractmethod def create( diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index a467bd94..7752be91 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,20 +1,21 @@ from __future__ import annotations +import io from abc import ABC from typing import Any, Dict, Optional -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient class StudioResource(ABC): - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): self._client = client def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index fa811541..b1387622 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -1,6 +1,6 @@ from typing import List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) @@ -18,7 +18,7 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -40,7 +40,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -53,6 +54,10 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: return response["versions"] + @classmethod + def _create_ai21_http_client(cls) -> AI21HTTPClient: + return AI21HTTPClient() + def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: diff --git a/ai21/tokenizers/factory.py b/ai21/tokenizers/factory.py index fd229231..cd728f77 100644 --- a/ai21/tokenizers/factory.py +++ b/ai21/tokenizers/factory.py @@ -16,6 +16,6 @@ def get_tokenizer() -> AI21Tokenizer: global _cached_tokenizer if _cached_tokenizer is None: - _cached_tokenizer = Tokenizer.get_tokenizer() + _cached_tokenizer = AI21Tokenizer(Tokenizer.get_tokenizer()) - return AI21Tokenizer(_cached_tokenizer) + return _cached_tokenizer diff --git a/poetry.lock b/poetry.lock index 334d859e..cd17ee01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -848,6 +848,23 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-mock" +version = "3.12.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1312,4 +1329,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "71ce6369e72538e571ac954b5ebb4e66fa79c1752aa61af336144df577078cc4" +content-hash = "39ea6a4fd93efce593b30be52de954f1d6ab4c2d39745a9541067a5af5f37a21" diff --git a/pyproject.toml b/pyproject.toml index 120b38ae..060d165b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ safety = "*" ruff = "*" python-semantic-release = "^8.5.0" pytest = "^7.4.3" +pytest-mock = "^3.12.0" [tool.poetry.extras] AWS = ["boto3"] diff --git a/tests/unittests/clients/__init__.py b/tests/unittests/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/__init__.py b/tests/unittests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/__init__.py b/tests/unittests/clients/studio/resources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py new file mode 100644 index 00000000..1a921fef --- /dev/null +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -0,0 +1,122 @@ +import pytest +from pytest_mock import MockerFixture + +from ai21 import AnswerResponse, ChatResponse, CompletionsResponse +from ai21.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.clients.studio.resources.studio_chat import StudioChat +from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.resources.responses.chat_response import ChatOutput, FinishReason +from ai21.resources.responses.completion_response import Prompt, Completion, CompletionData, CompletionFinishReason + + +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture) -> AI21HTTPClient: + return mocker.MagicMock(spec=AI21HTTPClient) + + +def get_studio_answer(): + _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" + _DUMMY_QUESTION = "What is the answer?" + + return ( + StudioAnswer, + {"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION}, + "answer", + { + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + AnswerResponse(id="some-id", answer_in_context=True, answer="42"), + ) + + +def get_studio_chat(): + _DUMMY_MODEL = "dummy-chat-model" + _DUMMY_MESSAGES = [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "name": "Alice", + }, + { + "text": "Hi Alice, I can help you with that. What seems to be the problem?", + "role": "assistant", + "name": "Bob", + }, + ] + _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" + + return ( + StudioChat, + {"model": _DUMMY_MODEL, "messages": _DUMMY_MESSAGES, "system": _DUMMY_SYSTEM}, + f"{_DUMMY_MODEL}/chat", + { + "model": _DUMMY_MODEL, + "system": _DUMMY_SYSTEM, + "messages": _DUMMY_MESSAGES, + "temperature": 0.7, + "maxTokens": 300, + "minTokens": 0, + "numResults": 1, + "topP": 1.0, + "topKReturn": 0, + "stopSequences": None, + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + ChatResponse( + outputs=[ + ChatOutput( + text="Hello, I need help with a signup process.", + role="user", + finish_reason=FinishReason(reason="dummy_reason", length=1, sequence="1"), + ) + ] + ), + ) + + +def get_studio_completion(): + _DUMMY_MODEL = "dummy-completion-model" + _DUMMY_PROMPT = "dummy-prompt" + + return ( + StudioCompletion, + {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT}, + f"{_DUMMY_MODEL}/complete", + { + "model": _DUMMY_MODEL, + "prompt": _DUMMY_PROMPT, + "temperature": 0.7, + "maxTokens": None, + "minTokens": 0, + "epoch": None, + "numResults": 1, + "topP": 1, + "customModel": None, + "experimentalModel": False, + "topKReturn": 0, + "stopSequences": [], + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + CompletionsResponse( + id="some-id", + completions=[ + Completion( + data=CompletionData(text="dummy-completion", tokens=[]), + finish_reason=CompletionFinishReason(reason="dummy_reason", length=1), + ) + ], + prompt=Prompt(text="dummy-prompt"), + ), + ) + + +def get_studio_custom_model(): + pass diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py new file mode 100644 index 00000000..0e4de3af --- /dev/null +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -0,0 +1,80 @@ +from typing import TypeVar, Callable + +import pytest + +from ai21 import AnswerResponse +from ai21.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.resources.studio_resource import StudioResource +from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion + +_BASE_URL = "https://test.api.ai21.com/studio/v1" +_DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" +_DUMMY_QUESTION = "What is the answer?" + +T = TypeVar("T", bound=StudioResource) + + +class TestStudioResources: + @pytest.mark.parametrize( + ids=[ + "studio_answer", + "studio_chat", + "studio_completion", + ], + argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], + argvalues=[ + (get_studio_answer()), + (get_studio_chat()), + (get_studio_completion()), + ], + ) + def test__create__should_return_answer_response( + self, + studio_resource: Callable[[AI21HTTPClient], T], + function_body, + url_suffix: str, + expected_body, + expected_response, + mock_ai21_studio_client: AI21HTTPClient, + ): + mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + + resource = studio_resource(mock_ai21_studio_client) + + actual_response = resource.create( + **function_body, + ) + + assert actual_response == expected_response + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=f"{_BASE_URL}/{url_suffix}", + params=expected_body, + files=None, + ) + + def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21HTTPClient): + expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") + mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + studio_answer = StudioAnswer(mock_ai21_studio_client) + + studio_answer.create( + context=_DUMMY_CONTEXT, + question=_DUMMY_QUESTION, + some_dummy_kwargs="some_dummy_value", + ) + + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=_BASE_URL + "/answer", + params={ + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + files=None, + ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py new file mode 100644 index 00000000..02e4d467 --- /dev/null +++ b/tests/unittests/conftest.py @@ -0,0 +1,12 @@ +import pytest +import requests + + +@pytest.fixture +def dummy_api_host() -> str: + return "http://test_host" + + +@pytest.fixture +def mock_requests_session(mocker) -> requests.Session: + return mocker.Mock(spec=requests.Session) diff --git a/tests/unittests/services/__init__.py b/tests/unittests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/services/sagemaker_stub.py b/tests/unittests/services/sagemaker_stub.py new file mode 100644 index 00000000..16cd98c2 --- /dev/null +++ b/tests/unittests/services/sagemaker_stub.py @@ -0,0 +1,12 @@ +from unittest.mock import Mock + +from ai21 import SageMaker +from ai21.ai21_http_client import AI21HTTPClient + + +class SageMakerStub(SageMaker): + ai21_http_client = Mock(spec=AI21HTTPClient) + + @classmethod + def _create_ai21_http_client(cls): + return cls.ai21_http_client diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py new file mode 100644 index 00000000..a92c23fe --- /dev/null +++ b/tests/unittests/services/test_sagemaker.py @@ -0,0 +1,44 @@ +import pytest + +from ai21.errors import ModelPackageDoesntExistException +from tests.unittests.services.sagemaker_stub import SageMakerStub + +_DUMMY_ARN = "some-model-package-id1" +_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] + + +class TestSageMakerService: + def test__get_model_package_arn__should_return_model_package_arn(self): + expected_response = { + "arn": _DUMMY_ARN, + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_ARN + + def test__get_model_package_arn__when_no_arn__should_raise_error(self): + SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + def test__list_model_package_versions__should_return_model_package_arn(self): + expected_response = { + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_VERSIONS + + def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") + + def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py new file mode 100644 index 00000000..67c4f197 --- /dev/null +++ b/tests/unittests/test_ai21_http_client.py @@ -0,0 +1,142 @@ +from typing import Optional + +import pytest +import requests + +from ai21.ai21_http_client import AI21HTTPClient +from ai21.http_client import HttpClient +from ai21.version import VERSION + +_DUMMY_API_KEY = "dummy_key" +_EXPECTED_GET_HEADERS = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + +_EXPECTED_POST_FILE_HEADERS = { + "Authorization": "Bearer dummy_key", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + + +class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + +class TestAI21StudioClient: + @pytest.mark.parametrize( + ids=[ + "when_pass_only_via__should_include_via_in_user_agent", + "when_pass_only_application__should_include_application_in_user_agent", + "when_pass_organization__should_include_organization_in_user_agent", + "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", + ], + argnames=["via", "application", "organization", "expected_user_agent"], + argvalues=[ + ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), + (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), + (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), + ( + "langchain", + "studio", + "ai21", + f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", + ), + ], + ) + def test__build_headers__user_agent( + self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str + ): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + assert client._http_client._headers["User-Agent"] == expected_user_agent + + def test__build_headers__authorization(self): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY) + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + def test__build_headers__when_pass_headers__should_append(self): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + assert client._http_client._headers["foo"] == "bar" + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + @pytest.mark.parametrize( + ids=[ + "when_api_host_is_not_set__should_return_default", + "when_api_host_is_set__should_return_set_value", + ], + argnames=["api_host", "expected_api_host"], + argvalues=[ + (None, "https://api.ai21.com/studio/v1"), + ("http://test_host", "http://test_host/studio/v1"), + ], + ) + def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + assert client.get_base_url() == expected_api_host + + @pytest.mark.parametrize( + ids=[ + "when_making_request__should_send_appropriate_parameters", + "when_making_request_with_files__should_send_appropriate_post_request", + ], + argnames=["params", "headers"], + argvalues=[ + ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ( + {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, + _EXPECTED_POST_FILE_HEADERS, + ), + ], + ) + def test__execute_http_request__( + self, + params, + headers, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + mock_requests_session.request.return_value = MockResponse(response_json, 200) + + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) + + response = client.execute_http_request(**params) + assert response == response_json + + if "files" in params: + # We split it because when calling requests with "files", "params" is turned into "data" + mock_requests_session.request.assert_called_once_with( + timeout=300, + headers=headers, + files=params["files"], + data=params["params"], + url=params["url"], + method=params["method"], + ) + else: + mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + + def test__execute_http_request__when_files_with_put_method__should_raise_value_error( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) + + mock_requests_session.request.return_value = MockResponse(response_json, 200) + with pytest.raises(ValueError): + params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + client.execute_http_request(**params) diff --git a/tests/unittests/test_dummy.py b/tests/unittests/test_dummy.py deleted file mode 100644 index 39b433bc..00000000 --- a/tests/unittests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_assert(): - assert True diff --git a/tests/unittests/tokenizers/__init__.py b/tests/unittests/tokenizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/tokenizers/test_ai21_tokenizer.py b/tests/unittests/tokenizers/test_ai21_tokenizer.py new file mode 100644 index 00000000..33f89d9e --- /dev/null +++ b/tests/unittests/tokenizers/test_ai21_tokenizer.py @@ -0,0 +1,25 @@ +from ai21.tokenizers.factory import get_tokenizer + + +class TestAI21Tokenizer: + def test__count_tokens__should_return_number_of_tokens(self): + expected_number_of_tokens = 8 + tokenizer = get_tokenizer() + + actual_number_of_tokens = tokenizer.count_tokens("Text to Tokenize - Hello world!") + + assert actual_number_of_tokens == expected_number_of_tokens + + def test__tokenize__should_return_list_of_tokens(self): + expected_tokens = ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"] + tokenizer = get_tokenizer() + + actual_tokens = tokenizer.tokenize("Text to Tokenize - Hello world!") + + assert actual_tokens == expected_tokens + + def test__tokenizer__should_be_singleton__when_called_twice(self): + tokenizer1 = get_tokenizer() + tokenizer2 = get_tokenizer() + + assert tokenizer1 is tokenizer2