Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions .github/workflows/quality-checks.yml

This file was deleted.

33 changes: 27 additions & 6 deletions ai21/ai21_studio_client.py → ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import io
from typing import Optional, Dict, Any


from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.errors import MissingApiKeyException
from ai21.http_client import HttpClient
from ai21.version import VERSION


class AI21StudioClient:
class AI21HTTPClient:
def __init__(
self,
*,
Expand All @@ -17,7 +19,9 @@ def __init__(
timeout_sec: Optional[int] = None,
num_retries: Optional[int] = None,
organization: Optional[str] = None,
application: Optional[str] = None,
via: Optional[str] = None,
http_client: Optional[HttpClient] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
self._env_config = env_config
Expand All @@ -32,12 +36,11 @@ def __init__(
self._timeout_sec = timeout_sec or self._env_config.timeout_sec
self._num_retries = num_retries or self._env_config.num_retries
self._organization = organization or self._env_config.organization
self._application = self._env_config.application
self._application = application or self._env_config.application
self._via = via

headers = self._build_headers(passed_headers=headers)

self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers)
self._http_client = self._init_http_client(http_client=http_client, headers=headers)

def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
headers = {
Expand All @@ -53,6 +56,18 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str,

return headers

def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient:
if http_client is None:
return HttpClient(
timeout_sec=self._timeout_sec,
num_retries=self._num_retries,
headers=headers,
)

http_client.add_headers(headers)

return http_client

def _build_user_agent(self) -> str:
user_agent = f"ai21 studio SDK {VERSION}"

Expand All @@ -67,8 +82,14 @@ def _build_user_agent(self) -> str:

return user_agent

def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None):
return self.http_client.execute_http_request(method=method, url=url, params=params, files=files)
def execute_http_request(
self,
method: str,
url: str,
params: Optional[Dict] = None,
files: Optional[Dict[str, io.TextIOWrapper]] = None,
):
return self._http_client.execute_http_request(method=method, url=url, params=params, files=files)

def get_base_url(self) -> str:
return f"{self._api_host}/studio/{self._api_version}"
33 changes: 18 additions & 15 deletions ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Any, Dict

from ai21.ai21_studio_client import AI21StudioClient
from ai21.ai21_http_client import AI21HTTPClient
from ai21.clients.studio.resources.studio_answer import StudioAnswer
from ai21.clients.studio.resources.studio_chat import StudioChat
from ai21.clients.studio.resources.studio_completion import StudioCompletion
Expand All @@ -14,6 +14,7 @@
from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation
from ai21.clients.studio.resources.studio_summarize import StudioSummarize
from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment
from ai21.http_client import HttpClient
from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer
from ai21.tokenizers.factory import get_tokenizer

Expand All @@ -33,29 +34,31 @@ def __init__(
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
via: Optional[str] = None,
http_client: Optional[HttpClient] = None,
**kwargs,
):
studio_client = AI21StudioClient(
self._http_client = AI21HTTPClient(
api_key=api_key,
api_host=api_host,
headers=headers,
timeout_sec=timeout_sec,
num_retries=num_retries,
via=via,
http_client=http_client,
)
self.completion = StudioCompletion(studio_client)
self.chat = StudioChat(studio_client)
self.summarize = StudioSummarize(studio_client)
self.embed = StudioEmbed(studio_client)
self.gec = StudioGEC(studio_client)
self.improvements = StudioImprovements(studio_client)
self.paraphrase = StudioParaphrase(studio_client)
self.summarize_by_segment = StudioSummarizeBySegment(studio_client)
self.custom_model = StudioCustomModel(studio_client)
self.dataset = StudioDataset(studio_client)
self.answer = StudioAnswer(studio_client)
self.library = StudioLibrary(studio_client)
self.segmentation = StudioSegmentation(studio_client)
self.completion = StudioCompletion(self._http_client)
self.chat = StudioChat(self._http_client)
self.summarize = StudioSummarize(self._http_client)
self.embed = StudioEmbed(self._http_client)
self.gec = StudioGEC(self._http_client)
self.improvements = StudioImprovements(self._http_client)
self.paraphrase = StudioParaphrase(self._http_client)
self.summarize_by_segment = StudioSummarizeBySegment(self._http_client)
self.custom_model = StudioCustomModel(self._http_client)
self.dataset = StudioDataset(self._http_client)
self.answer = StudioAnswer(self._http_client)
self.library = StudioLibrary(self._http_client)
self.segmentation = StudioSegmentation(self._http_client)

def count_tokens(self, text: str) -> int:
# We might want to cache the tokenizer instance within the class
Expand Down
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ def create(
presence_penalty=presence_penalty,
count_penalty=count_penalty,
)
url = f"{self._client.get_base_url()}/{model}/{self._module_name}"
url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}"
response = self._post(url=url, body=body)
return self._json_to_response(response)
2 changes: 1 addition & 1 deletion ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def create(
num_results: Optional[int] = 1,
min_tokens: Optional[int] = 0,
temperature: Optional[float] = 0.7,
top_p: Optional[int] = 1,
top_p: Optional[float] = 1,
top_k_return: Optional[int] = 0,
custom_model: Optional[str] = None,
experimental_mode: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions ai21/clients/studio/resources/studio_library.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, List

from ai21.ai21_studio_client import AI21StudioClient
from ai21.ai21_http_client import AI21HTTPClient
from ai21.resources.responses.file_response import FileResponse
from ai21.resources.responses.library_answer_response import LibraryAnswerResponse
from ai21.resources.responses.library_search_response import LibrarySearchResponse
Expand All @@ -10,7 +10,7 @@
class StudioLibrary(StudioResource):
_module_name = "library/files"

def __init__(self, client: AI21StudioClient):
def __init__(self, client: AI21HTTPClient):
super().__init__(client)
self.files = LibraryFiles(client)
self.search = LibrarySearch(client)
Expand Down
62 changes: 38 additions & 24 deletions ai21/http_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import json
from typing import Optional, Dict
from typing import Optional, Dict, Any

import requests
from requests.adapters import HTTPAdapter, Retry, RetryError
Expand Down Expand Up @@ -55,34 +56,35 @@ def requests_retry_session(session, retries=0):


class HttpClient:
def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None):
self.timeout_sec = timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC
self.num_retries = num_retries if num_retries is not None else DEFAULT_NUM_RETRIES
self.headers = headers if headers is not None else {}
self.apply_retry_policy = self.num_retries > 0
def __init__(
self,
session: Optional[requests.Session] = None,
timeout_sec: int = None,
num_retries: int = None,
headers: Dict = None,
):
self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC
self._num_retries = num_retries or DEFAULT_NUM_RETRIES
self._headers = headers or {}
self._apply_retry_policy = self._num_retries > 0
self._session = self._init_session(session)

def execute_http_request(
self,
method: str,
url: str,
params: Optional[Dict] = None,
files=None,
auth=None,
files: Optional[Dict[str, io.TextIOWrapper]] = None,
):
session = (
requests_retry_session(requests.Session(), retries=self.num_retries)
if self.apply_retry_policy
else requests.Session()
)
timeout = self.timeout_sec
headers = self.headers
timeout = self._timeout_sec
headers = self._headers
data = json.dumps(params).encode()
logger.info(f"Calling {method} {url} {headers} {data}")
try:
if method == "GET":
response = session.request(
method,
url,
response = self._session.request(
method=method,
url=url,
headers=headers,
timeout=timeout,
params=params,
Expand All @@ -96,23 +98,22 @@ def execute_http_request(
headers.pop(
"Content-Type"
) # multipart/form-data 'Content-Type' is being added when passing rb files and payload
response = session.request(
method,
url,
response = self._session.request(
method=method,
url=url,
headers=headers,
data=params,
files=files,
timeout=timeout,
auth=auth,
)
else:
response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth)
response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout)
except ConnectionError as connection_error:
logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}")
raise connection_error
except RetryError as retry_error:
logger.error(
f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}"
f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}"
)
raise retry_error
except Exception as exception:
Expand All @@ -124,3 +125,16 @@ def execute_http_request(
handle_non_success_response(response.status_code, response.text)

return response.json()

def _init_session(self, session: Optional[requests.Session]) -> requests.Session:
if session is not None:
return session

return (
requests_retry_session(requests.Session(), retries=self._num_retries)
if self._apply_retry_policy
else requests.Session()
)

def add_headers(self, headers: Dict[str, Any]) -> None:
self._headers.update(headers)
2 changes: 1 addition & 1 deletion ai21/resources/bases/chat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Message:


class Chat(ABC):
_module_name = "chat"
_MODULE_NAME = "chat"

@abstractmethod
def create(
Expand Down
7 changes: 4 additions & 3 deletions ai21/resources/studio_resource.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from __future__ import annotations

import io
from abc import ABC
from typing import Any, Dict, Optional

from ai21.ai21_studio_client import AI21StudioClient
from ai21.ai21_http_client import AI21HTTPClient


class StudioResource(ABC):
def __init__(self, client: AI21StudioClient):
def __init__(self, client: AI21HTTPClient):
self._client = client

def _post(
self,
url: str,
body: Dict[str, Any],
files: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, io.TextIOWrapper]] = None,
) -> Dict[str, Any]:
return self._client.execute_http_request(
method="POST",
Expand Down
11 changes: 8 additions & 3 deletions ai21/services/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from ai21.ai21_studio_client import AI21StudioClient
from ai21.ai21_http_client import AI21HTTPClient
from ai21.clients.sagemaker.constants import (
SAGEMAKER_MODEL_PACKAGE_NAMES,
)
Expand All @@ -18,7 +18,7 @@ class SageMaker:
def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str:
_assert_model_package_exists(model_name=model_name, region=region)

client = AI21StudioClient()
client = cls._create_ai21_http_client()

response = client.execute_http_request(
method="POST",
Expand All @@ -40,7 +40,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE
@classmethod
def list_model_package_versions(cls, model_name: str, region: str) -> List[str]:
_assert_model_package_exists(model_name=model_name, region=region)
client = AI21StudioClient()

client = cls._create_ai21_http_client()

response = client.execute_http_request(
method="POST",
Expand All @@ -53,6 +54,10 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]:

return response["versions"]

@classmethod
def _create_ai21_http_client(cls) -> AI21HTTPClient:
return AI21HTTPClient()


def _assert_model_package_exists(model_name, region):
if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES:
Expand Down
4 changes: 2 additions & 2 deletions ai21/tokenizers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def get_tokenizer() -> AI21Tokenizer:
global _cached_tokenizer

if _cached_tokenizer is None:
_cached_tokenizer = Tokenizer.get_tokenizer()
_cached_tokenizer = AI21Tokenizer(Tokenizer.get_tokenizer())

return AI21Tokenizer(_cached_tokenizer)
return _cached_tokenizer
Loading