diff --git a/README.md b/README.md index 03cbbae2..ac7edc0d 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ from ai21 import AI21Client client = AI21Client() -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path="path/to/file", path="path/to/file/in/library", labels=["label1", "label2"], @@ -213,7 +213,7 @@ try: except ai21_errors.AI21ServerError as e: print("Server error and could not be reached") print(e.details) -except ai21_errors.TooManyRequests as e: +except ai21_errors.TooManyRequestsError as e: print("A 429 status code was returned. Slow down on the requests") except AI21APIError as e: print("A non 200 status code error. For more error types see ai21.errors") diff --git a/ai21/__init__.py b/ai21/__init__.py index d614a1ca..6c5fb3e9 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,7 +1,14 @@ from typing import Any from ai21.clients.studio.ai21_client import AI21Client -from ai21.errors import AI21APIError, AI21APITimeoutError +from ai21.errors import ( + AI21APIError, + APITimeoutError, + MissingApiKeyError, + ModelPackageDoesntExistError, + AI21Error, + TooManyRequestsError, +) from ai21.logger import setup_logger from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.responses.chat_response import ChatResponse @@ -60,7 +67,11 @@ def __getattr__(name: str) -> Any: __all__ = [ "AI21Client", "AI21APIError", - "AI21APITimeoutError", + "APITimeoutError", + "AI21Error", + "MissingApiKeyError", + "ModelPackageDoesntExistError", + "TooManyRequestsError", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID", diff --git a/ai21/ai21_env_config.py b/ai21/ai21_env_config.py index 9f3a46be..01ef3501 100644 --- a/ai21/ai21_env_config.py +++ b/ai21/ai21_env_config.py @@ -11,8 +11,6 @@ class _AI21EnvConfig: api_key: Optional[str] = None api_version: str = DEFAULT_API_VERSION api_host: str = STUDIO_HOST - organization: Optional[str] = None - application: Optional[str] = None timeout_sec: Optional[int] = None num_retries: Optional[int] = None aws_region: Optional[str] = None @@ -24,8 +22,6 @@ def from_env(cls) -> _AI21EnvConfig: api_key=os.getenv("AI21_API_KEY"), api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), - organization=os.getenv("AI21_ORGANIZATION"), - application=os.getenv("AI21_APPLICATION"), timeout_sec=os.getenv("AI21_TIMEOUT_SEC"), num_retries=os.getenv("AI21_NUM_RETRIES"), aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"), diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 9bd3cb82..68007654 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,9 +1,7 @@ -import io -from typing import Optional, Dict, Any - +from typing import Optional, Dict, Any, BinaryIO from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig -from ai21.errors import MissingApiKeyException +from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -28,15 +26,15 @@ def __init__( self._api_key = api_key or self._env_config.api_key if self._api_key is None: - raise MissingApiKeyException() + raise MissingApiKeyError() self._api_host = api_host or self._env_config.api_host self._api_version = api_version or self._env_config.api_version self._headers = headers self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries - self._organization = organization or self._env_config.organization - self._application = application or self._env_config.application + self._organization = organization + self._application = application self._via = via headers = self._build_headers(passed_headers=headers) @@ -87,7 +85,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) diff --git a/ai21/clients/bedrock/bedrock_session.py b/ai21/clients/bedrock/bedrock_session.py index 7d9f846c..82029da6 100644 --- a/ai21/clients/bedrock/bedrock_session.py +++ b/ai21/clients/bedrock/bedrock_session.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from ai21.logger import logger -from ai21.errors import AccessDenied, NotFound, AI21APITimeoutError +from ai21.errors import AccessDenied, NotFound, APITimeoutError from ai21.http_client import handle_non_success_response _ERROR_MSG_TEMPLATE = ( @@ -52,7 +52,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None: raise NotFound(details=error_message) if status_code == 408: - raise AI21APITimeoutError(details=error_message) + raise APITimeoutError(details=error_message) if status_code == 424: error_message_template = re.compile(_ERROR_MSG_TEMPLATE) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index ba79621e..5cd12fac 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -15,7 +15,7 @@ def create( mode: Optional[str] = None, **kwargs, ) -> AnswerResponse: - url = f"{self._client.get_base_url()}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 8fe1bca4..f1dab12b 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{model}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 48f85fa7..10c1890f 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -18,7 +18,6 @@ def create( top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Dict[str, Any]] = None, presence_penalty: Optional[Dict[str, Any]] = None, @@ -26,9 +25,6 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: - if experimental_mode: - model = f"experimental/{model}" - url = f"{self._client.get_base_url()}/{model}" if custom_model is not None: @@ -45,7 +41,6 @@ def create( top_p=top_p, top_k_return=top_k_return, custom_model=custom_model, - experimental_mode=experimental_mode, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 05a07c52..8626d71b 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -6,7 +6,7 @@ class StudioDataset(StudioResource, Dataset): - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 50895e24..86287781 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -1,6 +1,6 @@ from typing import List -from ai21.errors import EmptyMandatoryListException +from ai21.errors import EmptyMandatoryListError from ai21.resources.bases.improvements_base import Improvements from ai21.resources.responses.improvement_response import ImprovementsResponse from ai21.resources.studio_resource import StudioResource @@ -9,7 +9,7 @@ class StudioImprovements(StudioResource, Improvements): def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: if len(types) == 0: - raise EmptyMandatoryListException("types") + raise EmptyMandatoryListError("types") url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, types=types) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 42daedbb..b8f96a3c 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -20,7 +20,7 @@ def __init__(self, client: AI21HTTPClient): class LibraryFiles(StudioResource): _module_name = "library/files" - def upload( + def create( self, file_path: str, *, diff --git a/ai21/errors.py b/ai21/errors.py index 33cf336b..4a0f8c92 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -31,7 +31,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(404, details) -class AI21APITimeoutError(AI21APIError): +class APITimeoutError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(408, details) @@ -41,7 +41,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(422, details) -class TooManyRequests(AI21APIError): +class TooManyRequestsError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(429, details) @@ -56,7 +56,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(500, details) -class AI21ClientException(Exception): +class AI21Error(Exception): def __init__(self, message: str): self.message = message super().__init__(message) @@ -65,57 +65,14 @@ def __str__(self) -> str: return f"{type(self).__name__} {self.message}" -class MissingInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is required for the {call_name} call" - super().__init__(message) - - -class UnsupportedInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is unsupported for the {call_name} call" - super().__init__(message) - - -class UnsupportedDestinationException(AI21ClientException): - def __init__(self, destination_name: str, call_name: str): - message = f'Destination of type {destination_name} is unsupported for the "{call_name}" call' - super().__init__(message) - - -class OnlyOneInputException(AI21ClientException): - def __init__(self, field_name1: str, field_name2: str, call_name: str): - message = f"{field_name1} or {field_name2} is required for the {call_name} call, but not both" - super().__init__(message) - - -class WrongInputTypeException(AI21ClientException): - def __init__(self, key: str, expected_type: type, given_type: type): - message = f"Supplied {key} should be {expected_type}, but {given_type} was passed instead" - super().__init__(message) - - -class EmptyMandatoryListException(AI21ClientException): - def __init__(self, key: str): - message = f"Supplied {key} is empty. At least one element should be present in the list" - super().__init__(message) - - -class MissingApiKeyException(AI21ClientException): +class MissingApiKeyError(AI21Error): def __init__(self): message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" super().__init__(message) self.message = message -class NoSpecifiedRegionException(AI21ClientException): - def __init__(self): - message = "No AWS region provided" - super().__init__(message) - self.message = message - - -class ModelPackageDoesntExistException(AI21ClientException): +class ModelPackageDoesntExistError(AI21Error): def __init__(self, model_name: str, region: str, version: Optional[str] = None): message = f"model_name: {model_name} doesn't exist in region: {region}" @@ -124,3 +81,9 @@ def __init__(self, model_name: str, region: str, version: Optional[str] = None): super().__init__(message) self.message = message + + +class EmptyMandatoryListError(AI21Error): + def __init__(self, key: str): + message = f"Supplied {key} is empty. At least one element should be present in the list" + super().__init__(message) diff --git a/ai21/http_client.py b/ai21/http_client.py index 00692e07..0eeac1a1 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,20 +1,19 @@ -import io import json -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, BinaryIO import requests from requests.adapters import HTTPAdapter, Retry, RetryError -from ai21.logger import logger from ai21.errors import ( BadRequest, Unauthorized, UnprocessableEntity, - TooManyRequests, + TooManyRequestsError, AI21ServerError, ServiceUnavailable, AI21APIError, ) +from ai21.logger import logger DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 @@ -32,7 +31,7 @@ def handle_non_success_response(status_code: int, response_text: str): if status_code == 422: raise UnprocessableEntity(details=response_text) if status_code == 429: - raise TooManyRequests(details=response_text) + raise TooManyRequestsError(details=response_text) if status_code == 500: raise AI21ServerError(details=response_text) if status_code == 503: @@ -74,7 +73,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): timeout = self._timeout_sec headers = self._headers diff --git a/ai21/resources/bases/answer_base.py b/ai21/resources/bases/answer_base.py index 0fbce8c0..4b11ff5c 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/resources/bases/answer_base.py @@ -5,7 +5,7 @@ class Answer(ABC): - _MODULE_NAME = "answer" + _module_name = "answer" def create( self, diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index e2a67c0d..f85270ee 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _MODULE_NAME = "chat" + _module_name = "chat" @abstractmethod def create( diff --git a/ai21/resources/bases/completion_base.py b/ai21/resources/bases/completion_base.py index cb286df2..f549306a 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/resources/bases/completion_base.py @@ -20,7 +20,6 @@ def create( top_p=1, top_k_return=0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = (), frequency_penalty: Optional[Dict[str, Any]] = {}, presence_penalty: Optional[Dict[str, Any]] = {}, @@ -44,7 +43,6 @@ def _create_body( top_p: Optional[int], top_k_return: Optional[int], custom_model: Optional[str], - experimental_mode: bool, stop_sequences: Optional[List[str]], frequency_penalty: Optional[Dict[str, Any]], presence_penalty: Optional[Dict[str, Any]], @@ -54,7 +52,6 @@ def _create_body( return { "model": model, "customModel": custom_model, - "experimentalModel": experimental_mode, "prompt": prompt, "maxTokens": max_tokens, "numResults": num_results, diff --git a/ai21/resources/bases/dataset_base.py b/ai21/resources/bases/dataset_base.py index dd53417c..2be49fc7 100644 --- a/ai21/resources/bases/dataset_base.py +++ b/ai21/resources/bases/dataset_base.py @@ -8,7 +8,7 @@ class Dataset(ABC): _module_name = "dataset" @abstractmethod - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index 7752be91..8ece396e 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,8 +1,7 @@ from __future__ import annotations -import io from abc import ABC -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, BinaryIO from ai21.ai21_http_client import AI21HTTPClient @@ -15,7 +14,7 @@ def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index b1387622..f51e1ae2 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -4,7 +4,7 @@ from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError _JUMPSTART_ENDPOINT = "jumpstart" _LIST_VERSIONS_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/list_versions" @@ -33,7 +33,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE arn = response["arn"] if not arn: - raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) + raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) return arn @@ -61,4 +61,4 @@ def _create_ai21_http_client(cls) -> AI21HTTPClient: def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: - raise ModelPackageDoesntExistException(model_name=model_name, region=region) + raise ModelPackageDoesntExistError(model_name=model_name, region=region) diff --git a/examples/studio/dataset.py b/examples/studio/dataset.py index b07d6565..87e587cc 100644 --- a/examples/studio/dataset.py +++ b/examples/studio/dataset.py @@ -3,7 +3,7 @@ file_path = "" client = AI21Client() -client.dataset.upload(file_path=file_path, dataset_name="my_new_ds_name") +client.dataset.create(file_path=file_path, dataset_name="my_new_ds_name") result = client.dataset.list() print(result) first_ds_id = result[0].id diff --git a/examples/studio/library.py b/examples/studio/library.py index e1377200..d693d697 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -24,7 +24,7 @@ def validate_file_deleted(): path = os.path.join(file_path, file_name) file_utils.create_file(file_path, file_name, content="test content" * 100) -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path=path, path=file_path, labels=["label1", "label2"], diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 1a921fef..6d94f2a7 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -98,7 +98,6 @@ def get_studio_completion(): "numResults": 1, "topP": 1, "customModel": None, - "experimentalModel": False, "topKReturn": 0, "stopSequences": [], "frequencyPenalty": None, diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index a92c23fe..dd36e1c9 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,6 +1,6 @@ import pytest -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError from tests.unittests.services.sagemaker_stub import SageMakerStub _DUMMY_ARN = "some-model-package-id1" @@ -22,7 +22,7 @@ def test__get_model_package_arn__should_return_model_package_arn(self): def test__get_model_package_arn__when_no_arn__should_raise_error(self): SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") def test__list_model_package_versions__should_return_model_package_arn(self): @@ -36,9 +36,9 @@ def test__list_model_package_versions__should_return_model_package_arn(self): assert actual_model_package_arn == _DUMMY_VERSIONS def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1")