Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions ai21/clients/studio/resources/studio_library.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations
from typing import Optional, List

from ai21.ai21_http_client import AI21HTTPClient
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given


class StudioLibrary(StudioResource):
Expand All @@ -22,14 +25,14 @@ def create(
self,
file_path: str,
*,
path: Optional[str] = None,
labels: Optional[List[str]] = None,
public_url: Optional[str] = None,
path: Optional[str] | NotGiven = NOT_GIVEN,
labels: Optional[List[str]] | NotGiven = NOT_GIVEN,
public_url: Optional[str] | NotGiven = NOT_GIVEN,
**kwargs,
) -> str:
url = f"{self._client.get_base_url()}/{self._module_name}"
files = {"file": open(file_path, "rb")}
body = {"path": path, "labels": labels, "publicUrl": public_url, **kwargs}
body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs})

raw_response = self._post(url=url, files=files, body=body)

Expand All @@ -44,12 +47,12 @@ def get(self, file_id: str) -> FileResponse:
def list(
self,
*,
offset: Optional[int] = None,
limit: Optional[int] = None,
offset: Optional[int] | NotGiven = NOT_GIVEN,
limit: Optional[int] | NotGiven = NOT_GIVEN,
**kwargs,
) -> List[FileResponse]:
url = f"{self._client.get_base_url()}/{self._module_name}"
params = {"offset": offset, "limit": limit}
params = remove_not_given({"offset": offset, "limit": limit})
raw_response = self._get(url=url, params=params)

return [FileResponse.from_dict(file) for file in raw_response]
Expand All @@ -58,16 +61,18 @@ def update(
self,
file_id: str,
*,
public_url: Optional[str] = None,
labels: Optional[List[str]] = None,
public_url: Optional[str] | NotGiven = NOT_GIVEN,
labels: Optional[List[str]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> None:
url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}"
body = {
"publicUrl": public_url,
"labels": labels,
**kwargs,
}
body = remove_not_given(
{
"publicUrl": public_url,
"labels": labels,
**kwargs,
}
)
self._put(url=url, body=body)

def delete(self, file_id: str) -> None:
Expand All @@ -82,19 +87,21 @@ def create(
self,
query: str,
*,
path: Optional[str] = None,
field_ids: Optional[List[str]] = None,
max_segments: Optional[int] = None,
path: Optional[str] | NotGiven = NOT_GIVEN,
field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN,
max_segments: Optional[int] | NotGiven = NOT_GIVEN,
**kwargs,
) -> LibrarySearchResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = {
"query": query,
"path": path,
"fieldIds": field_ids,
"maxSegments": max_segments,
**kwargs,
}
body = remove_not_given(
{
"query": query,
"path": path,
"fieldIds": field_ids,
"maxSegments": max_segments,
**kwargs,
}
)
raw_response = self._post(url=url, body=body)
return LibrarySearchResponse.from_dict(raw_response)

Expand All @@ -106,18 +113,20 @@ def create(
self,
question: str,
*,
path: Optional[str] = None,
field_ids: Optional[List[str]] = None,
labels: Optional[List[str]] = None,
path: Optional[str] | NotGiven = NOT_GIVEN,
field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN,
labels: Optional[List[str]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> LibraryAnswerResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = {
"question": question,
"path": path,
"fieldIds": field_ids,
"labels": labels,
**kwargs,
}
body = remove_not_given(
{
"question": question,
"path": path,
"fieldIds": field_ids,
"labels": labels,
**kwargs,
}
)
raw_response = self._post(url=url, body=body)
return LibraryAnswerResponse.from_dict(raw_response)
133 changes: 71 additions & 62 deletions ai21/http_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from typing import Optional, Dict, Any, BinaryIO

import requests
from requests.adapters import HTTPAdapter, Retry, RetryError
import httpx
from httpx import ConnectError
from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential, RetryError

from ai21.errors import (
BadRequest,
Expand All @@ -17,8 +18,8 @@

DEFAULT_TIMEOUT_SEC = 300
DEFAULT_NUM_RETRIES = 0
TIME_BETWEEN_RETRIES = 1
RETRY_BACK_OFF_FACTOR = 0.5
TIME_BETWEEN_RETRIES = 1000
RETRY_ERROR_CODES = (408, 429, 500, 503)
RETRY_METHOD_WHITELIST = ["GET", "POST", "PUT"]

Expand All @@ -39,25 +40,16 @@ def handle_non_success_response(status_code: int, response_text: str):
raise AI21APIError(status_code, details=response_text)


def requests_retry_session(session, retries=0):
retry = Retry(
total=retries,
read=retries,
connect=retries,
backoff_factor=RETRY_BACK_OFF_FACTOR,
status_forcelist=RETRY_ERROR_CODES,
allowed_methods=frozenset(RETRY_METHOD_WHITELIST),
def _requests_retry_session(retries: int) -> httpx.HTTPTransport:
return httpx.HTTPTransport(
retries=retries,
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)
session.mount("http://", adapter)
return session


class HttpClient:
def __init__(
self,
session: Optional[requests.Session] = None,
client: Optional[httpx.Client] = None,
timeout_sec: int = None,
num_retries: int = None,
headers: Dict = None,
Expand All @@ -66,7 +58,18 @@ def __init__(
self._num_retries = num_retries or DEFAULT_NUM_RETRIES
self._headers = headers or {}
self._apply_retry_policy = self._num_retries > 0
self._session = self._init_session(session)
self._client = self._init_client(client)

# Since we can't use the retry decorator on a method of a class as we can't access class attributes,
# we have to wrap the method in a function
self._request = retry(
wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES),
retry=retry_if_result(self._should_retry),
stop=stop_after_attempt(self._num_retries),
)(self._request)

def _should_retry(self, response: httpx.Response) -> bool:
return response.status_code in RETRY_ERROR_CODES and response.request.method in RETRY_METHOD_WHITELIST

def execute_http_request(
self,
Expand All @@ -75,66 +78,72 @@ def execute_http_request(
params: Optional[Dict] = None,
files: Optional[Dict[str, BinaryIO]] = None,
):
timeout = self._timeout_sec
headers = self._headers
data = json.dumps(params).encode()
logger.debug(f"Calling {method} {url} {headers} {data}")

try:
if method == "GET":
response = self._session.request(
method=method,
url=url,
headers=headers,
timeout=timeout,
params=params,
)
elif files is not None:
if method != "POST":
raise ValueError(
f"execute_http_request supports only POST for files upload, but {method} was supplied instead"
)
if "Content-Type" in headers:
headers.pop(
"Content-Type"
) # multipart/form-data 'Content-Type' is being added when passing rb files and payload
response = self._session.request(
method=method,
url=url,
headers=headers,
data=params,
files=files,
timeout=timeout,
)
response = self._request(files=files, method=method, params=params, url=url)
except RetryError as retry_error:
last_attempt = retry_error.last_attempt

if last_attempt.failed:
raise last_attempt.exception()
else:
response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout)
except ConnectionError as connection_error:
response = last_attempt.result()

except ConnectError as connection_error:
logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}")
raise connection_error
except RetryError as retry_error:
logger.error(
f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}"
)
raise retry_error
except Exception as exception:
logger.error(f"Calling {method} {url} failed with Exception: {exception}")
raise exception

if response.status_code != 200:
if response.status_code != httpx.codes.OK:
logger.error(f"Calling {method} {url} failed with a non-200 response code: {response.status_code}")
handle_non_success_response(response.status_code, response.text)

return response.json()

def _init_session(self, session: Optional[requests.Session]) -> requests.Session:
if session is not None:
return session
def _request(
self, files: Optional[Dict[str, BinaryIO]], method: str, params: Optional[Dict], url: str
) -> httpx.Response:
timeout = self._timeout_sec
headers = self._headers
logger.debug(f"Calling {method} {url} {headers} {params}")

if method == "GET":
return self._client.request(
method=method,
url=url,
headers=headers,
timeout=timeout,
params=params,
)

return (
requests_retry_session(requests.Session(), retries=self._num_retries)
if self._apply_retry_policy
else requests.Session()
if files is not None:
if method != "POST":
raise ValueError(
f"execute_http_request supports only POST for files upload, but {method} was supplied instead"
)
if "Content-Type" in headers:
headers.pop(
"Content-Type"
) # multipart/form-data 'Content-Type' is being added when passing rb files and payload
data = params
else:
data = json.dumps(params).encode() if params else None

return self._client.request(
method=method,
url=url,
headers=headers,
data=data,
timeout=timeout,
files=files,
)

def _init_client(self, client: Optional[httpx.Client]) -> httpx.Client:
if client is not None:
return client

return _requests_retry_session(retries=self._num_retries) if self._apply_retry_policy else httpx.Client()

def add_headers(self, headers: Dict[str, Any]) -> None:
self._headers.update(headers)
Loading