Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a7be775
test: get_tokenizer tests
asafgardin Dec 21, 2023
22fee63
fix: cases
asafgardin Dec 21, 2023
68fd37e
test: Added some unittests to resources
asafgardin Dec 21, 2023
bdac3c4
fix: rename var
asafgardin Dec 21, 2023
d1f8be1
test: Added ai21 studio client tsts
asafgardin Dec 21, 2023
287b4eb
fix: rename files
asafgardin Dec 21, 2023
8697aee
fix: Added types
asafgardin Dec 24, 2023
0b3a1f6
test: added test to http
asafgardin Dec 24, 2023
c7fa29f
fix: removed unnecessary auth param
asafgardin Dec 24, 2023
37a57cd
test: Added tests
asafgardin Dec 24, 2023
0b4c6be
test: Added sagemaker
asafgardin Dec 25, 2023
c2f15a1
test: Created a single session per instance
asafgardin Dec 25, 2023
e7cf601
ci: removed unnecessary action
asafgardin Dec 25, 2023
68bc456
fix: errors
asafgardin Dec 26, 2023
6118385
fix: error renames
asafgardin Dec 26, 2023
4b3c2f3
fix: rename upload
asafgardin Dec 26, 2023
ae9e8ad
fix: rename type
asafgardin Dec 26, 2023
c55cbee
fix: rename variable
asafgardin Dec 26, 2023
e316760
fix: removed experimental
asafgardin Dec 26, 2023
6dc7614
test: fixed
asafgardin Dec 27, 2023
69f50e7
test: Added some unittests to resources
asafgardin Dec 21, 2023
b6d96ff
test: Added ai21 studio client tsts
asafgardin Dec 21, 2023
9b52690
fix: rename files
asafgardin Dec 21, 2023
c8cba0a
fix: Added types
asafgardin Dec 24, 2023
1d4cf23
test: added test to http
asafgardin Dec 24, 2023
750b57d
fix: removed unnecessary auth param
asafgardin Dec 24, 2023
c469376
test: Added tests
asafgardin Dec 24, 2023
ad6f5c1
test: Added sagemaker
asafgardin Dec 25, 2023
826dd57
test: Created a single session per instance
asafgardin Dec 25, 2023
d1cbea1
fix: errors
asafgardin Dec 26, 2023
6283a99
fix: error renames
asafgardin Dec 26, 2023
4bbd6a2
fix: rename upload
asafgardin Dec 26, 2023
4afdfce
fix: rename type
asafgardin Dec 26, 2023
70bdd9a
fix: rename variable
asafgardin Dec 26, 2023
3d325d9
fix: removed experimental
asafgardin Dec 26, 2023
c8bcf10
test: fixed
asafgardin Dec 27, 2023
2fe23f9
Merge remote-tracking branch 'origin/feedback_fixes' into feedback_fixes
asafgardin Dec 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ from ai21 import AI21Client

client = AI21Client()

file_id = client.library.files.upload(
file_id = client.library.files.create(
file_path="path/to/file",
path="path/to/file/in/library",
labels=["label1", "label2"],
Expand Down Expand Up @@ -213,7 +213,7 @@ try:
except ai21_errors.AI21ServerError as e:
print("Server error and could not be reached")
print(e.details)
except ai21_errors.TooManyRequests as e:
except ai21_errors.TooManyRequestsError as e:
print("A 429 status code was returned. Slow down on the requests")
except AI21APIError as e:
print("A non 200 status code error. For more error types see ai21.errors")
Expand Down
15 changes: 13 additions & 2 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Any

from ai21.clients.studio.ai21_client import AI21Client
from ai21.errors import AI21APIError, AI21APITimeoutError
from ai21.errors import (
AI21APIError,
APITimeoutError,
MissingApiKeyError,
ModelPackageDoesntExistError,
AI21Error,
TooManyRequestsError,
)
from ai21.logger import setup_logger
from ai21.resources.responses.answer_response import AnswerResponse
from ai21.resources.responses.chat_response import ChatResponse
Expand Down Expand Up @@ -60,7 +67,11 @@ def __getattr__(name: str) -> Any:
__all__ = [
"AI21Client",
"AI21APIError",
"AI21APITimeoutError",
"APITimeoutError",
"AI21Error",
"MissingApiKeyError",
"ModelPackageDoesntExistError",
"TooManyRequestsError",
"AI21BedrockClient",
"AI21SageMakerClient",
"BedrockModelID",
Expand Down
4 changes: 0 additions & 4 deletions ai21/ai21_env_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ class _AI21EnvConfig:
api_key: Optional[str] = None
api_version: str = DEFAULT_API_VERSION
api_host: str = STUDIO_HOST
organization: Optional[str] = None
application: Optional[str] = None
timeout_sec: Optional[int] = None
num_retries: Optional[int] = None
aws_region: Optional[str] = None
Expand All @@ -24,8 +22,6 @@ def from_env(cls) -> _AI21EnvConfig:
api_key=os.getenv("AI21_API_KEY"),
api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION),
api_host=os.getenv("AI21_API_HOST", STUDIO_HOST),
organization=os.getenv("AI21_ORGANIZATION"),
application=os.getenv("AI21_APPLICATION"),
timeout_sec=os.getenv("AI21_TIMEOUT_SEC"),
num_retries=os.getenv("AI21_NUM_RETRIES"),
aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"),
Expand Down
14 changes: 6 additions & 8 deletions ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import io
from typing import Optional, Dict, Any

from typing import Optional, Dict, Any, BinaryIO

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.errors import MissingApiKeyException
from ai21.errors import MissingApiKeyError
from ai21.http_client import HttpClient
from ai21.version import VERSION

Expand All @@ -28,15 +26,15 @@ def __init__(
self._api_key = api_key or self._env_config.api_key

if self._api_key is None:
raise MissingApiKeyException()
raise MissingApiKeyError()

self._api_host = api_host or self._env_config.api_host
self._api_version = api_version or self._env_config.api_version
self._headers = headers
self._timeout_sec = timeout_sec or self._env_config.timeout_sec
self._num_retries = num_retries or self._env_config.num_retries
self._organization = organization or self._env_config.organization
self._application = application or self._env_config.application
self._organization = organization
self._application = application
self._via = via

headers = self._build_headers(passed_headers=headers)
Expand Down Expand Up @@ -87,7 +85,7 @@ def execute_http_request(
method: str,
url: str,
params: Optional[Dict] = None,
files: Optional[Dict[str, io.TextIOWrapper]] = None,
files: Optional[Dict[str, BinaryIO]] = None,
):
return self._http_client.execute_http_request(method=method, url=url, params=params, files=files)

Expand Down
4 changes: 2 additions & 2 deletions ai21/clients/bedrock/bedrock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from botocore.exceptions import ClientError

from ai21.logger import logger
from ai21.errors import AccessDenied, NotFound, AI21APITimeoutError
from ai21.errors import AccessDenied, NotFound, APITimeoutError
from ai21.http_client import handle_non_success_response

_ERROR_MSG_TEMPLATE = (
Expand Down Expand Up @@ -52,7 +52,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None:
raise NotFound(details=error_message)

if status_code == 408:
raise AI21APITimeoutError(details=error_message)
raise APITimeoutError(details=error_message)

if status_code == 424:
error_message_template = re.compile(_ERROR_MSG_TEMPLATE)
Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def create(
mode: Optional[str] = None,
**kwargs,
) -> AnswerResponse:
url = f"{self._client.get_base_url()}/{self._MODULE_NAME}"
url = f"{self._client.get_base_url()}/{self._module_name}"

body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode)

Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ def create(
presence_penalty=presence_penalty,
count_penalty=count_penalty,
)
url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}"
url = f"{self._client.get_base_url()}/{model}/{self._module_name}"
response = self._post(url=url, body=body)
return self._json_to_response(response)
5 changes: 0 additions & 5 deletions ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,13 @@ def create(
top_p: Optional[float] = 1,
top_k_return: Optional[int] = 0,
custom_model: Optional[str] = None,
experimental_mode: bool = False,
stop_sequences: Optional[List[str]] = None,
frequency_penalty: Optional[Dict[str, Any]] = None,
presence_penalty: Optional[Dict[str, Any]] = None,
count_penalty: Optional[Dict[str, Any]] = None,
epoch: Optional[int] = None,
**kwargs,
) -> CompletionsResponse:
if experimental_mode:
model = f"experimental/{model}"

url = f"{self._client.get_base_url()}/{model}"

if custom_model is not None:
Expand All @@ -45,7 +41,6 @@ def create(
top_p=top_p,
top_k_return=top_k_return,
custom_model=custom_model,
experimental_mode=experimental_mode,
stop_sequences=stop_sequences,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class StudioDataset(StudioResource, Dataset):
def upload(
def create(
self,
file_path: str,
dataset_name: str,
Expand Down
4 changes: 2 additions & 2 deletions ai21/clients/studio/resources/studio_improvements.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from ai21.errors import EmptyMandatoryListException
from ai21.errors import EmptyMandatoryListError
from ai21.resources.bases.improvements_base import Improvements
from ai21.resources.responses.improvement_response import ImprovementsResponse
from ai21.resources.studio_resource import StudioResource
Expand All @@ -9,7 +9,7 @@
class StudioImprovements(StudioResource, Improvements):
def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse:
if len(types) == 0:
raise EmptyMandatoryListException("types")
raise EmptyMandatoryListError("types")

url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(text=text, types=types)
Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, client: AI21HTTPClient):
class LibraryFiles(StudioResource):
_module_name = "library/files"

def upload(
def create(
self,
file_path: str,
*,
Expand Down
59 changes: 11 additions & 48 deletions ai21/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, details: Optional[str] = None):
super().__init__(404, details)


class AI21APITimeoutError(AI21APIError):
class APITimeoutError(AI21APIError):
def __init__(self, details: Optional[str] = None):
super().__init__(408, details)

Expand All @@ -41,7 +41,7 @@ def __init__(self, details: Optional[str] = None):
super().__init__(422, details)


class TooManyRequests(AI21APIError):
class TooManyRequestsError(AI21APIError):
def __init__(self, details: Optional[str] = None):
super().__init__(429, details)

Expand All @@ -56,7 +56,7 @@ def __init__(self, details: Optional[str] = None):
super().__init__(500, details)


class AI21ClientException(Exception):
class AI21Error(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
Expand All @@ -65,57 +65,14 @@ def __str__(self) -> str:
return f"{type(self).__name__} {self.message}"


class MissingInputException(AI21ClientException):
def __init__(self, field_name: str, call_name: str):
message = f"{field_name} is required for the {call_name} call"
super().__init__(message)


class UnsupportedInputException(AI21ClientException):
def __init__(self, field_name: str, call_name: str):
message = f"{field_name} is unsupported for the {call_name} call"
super().__init__(message)


class UnsupportedDestinationException(AI21ClientException):
def __init__(self, destination_name: str, call_name: str):
message = f'Destination of type {destination_name} is unsupported for the "{call_name}" call'
super().__init__(message)


class OnlyOneInputException(AI21ClientException):
def __init__(self, field_name1: str, field_name2: str, call_name: str):
message = f"{field_name1} or {field_name2} is required for the {call_name} call, but not both"
super().__init__(message)


class WrongInputTypeException(AI21ClientException):
def __init__(self, key: str, expected_type: type, given_type: type):
message = f"Supplied {key} should be {expected_type}, but {given_type} was passed instead"
super().__init__(message)


class EmptyMandatoryListException(AI21ClientException):
def __init__(self, key: str):
message = f"Supplied {key} is empty. At least one element should be present in the list"
super().__init__(message)


class MissingApiKeyException(AI21ClientException):
class MissingApiKeyError(AI21Error):
def __init__(self):
message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args"
super().__init__(message)
self.message = message


class NoSpecifiedRegionException(AI21ClientException):
def __init__(self):
message = "No AWS region provided"
super().__init__(message)
self.message = message


class ModelPackageDoesntExistException(AI21ClientException):
class ModelPackageDoesntExistError(AI21Error):
def __init__(self, model_name: str, region: str, version: Optional[str] = None):
message = f"model_name: {model_name} doesn't exist in region: {region}"

Expand All @@ -124,3 +81,9 @@ def __init__(self, model_name: str, region: str, version: Optional[str] = None):

super().__init__(message)
self.message = message


class EmptyMandatoryListError(AI21Error):
def __init__(self, key: str):
message = f"Supplied {key} is empty. At least one element should be present in the list"
super().__init__(message)
11 changes: 5 additions & 6 deletions ai21/http_client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import io
import json
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, BinaryIO

import requests
from requests.adapters import HTTPAdapter, Retry, RetryError

from ai21.logger import logger
from ai21.errors import (
BadRequest,
Unauthorized,
UnprocessableEntity,
TooManyRequests,
TooManyRequestsError,
AI21ServerError,
ServiceUnavailable,
AI21APIError,
)
from ai21.logger import logger

DEFAULT_TIMEOUT_SEC = 300
DEFAULT_NUM_RETRIES = 0
Expand All @@ -32,7 +31,7 @@ def handle_non_success_response(status_code: int, response_text: str):
if status_code == 422:
raise UnprocessableEntity(details=response_text)
if status_code == 429:
raise TooManyRequests(details=response_text)
raise TooManyRequestsError(details=response_text)
if status_code == 500:
raise AI21ServerError(details=response_text)
if status_code == 503:
Expand Down Expand Up @@ -74,7 +73,7 @@ def execute_http_request(
method: str,
url: str,
params: Optional[Dict] = None,
files: Optional[Dict[str, io.TextIOWrapper]] = None,
files: Optional[Dict[str, BinaryIO]] = None,
):
timeout = self._timeout_sec
headers = self._headers
Expand Down
2 changes: 1 addition & 1 deletion ai21/resources/bases/answer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class Answer(ABC):
_MODULE_NAME = "answer"
_module_name = "answer"

def create(
self,
Expand Down
2 changes: 1 addition & 1 deletion ai21/resources/bases/chat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Message:


class Chat(ABC):
_MODULE_NAME = "chat"
_module_name = "chat"

@abstractmethod
def create(
Expand Down
3 changes: 0 additions & 3 deletions ai21/resources/bases/completion_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def create(
top_p=1,
top_k_return=0,
custom_model: Optional[str] = None,
experimental_mode: bool = False,
stop_sequences: Optional[List[str]] = (),
frequency_penalty: Optional[Dict[str, Any]] = {},
presence_penalty: Optional[Dict[str, Any]] = {},
Expand All @@ -44,7 +43,6 @@ def _create_body(
top_p: Optional[int],
top_k_return: Optional[int],
custom_model: Optional[str],
experimental_mode: bool,
stop_sequences: Optional[List[str]],
frequency_penalty: Optional[Dict[str, Any]],
presence_penalty: Optional[Dict[str, Any]],
Expand All @@ -54,7 +52,6 @@ def _create_body(
return {
"model": model,
"customModel": custom_model,
"experimentalModel": experimental_mode,
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
Expand Down
2 changes: 1 addition & 1 deletion ai21/resources/bases/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Dataset(ABC):
_module_name = "dataset"

@abstractmethod
def upload(
def create(
self,
file_path: str,
dataset_name: str,
Expand Down
Loading