Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
*,
api_key: Optional[str] = None,
requires_api_key: bool = True,
api_host: Optional[str] = None,
base_url: Optional[str] = None,
api_version: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[int] = None,
Expand All @@ -27,7 +27,7 @@ def __init__(
if requires_api_key and not self._api_key:
raise MissingApiKeyError()

self._api_host = api_host
self._base_url = base_url
self._api_version = api_version
self._headers = headers
self._timeout_sec = timeout_sec
Expand Down Expand Up @@ -76,12 +76,17 @@ def _build_user_agent(self) -> str:
def execute_http_request(
self,
method: str,
url: str,
path: str,
params: Optional[Dict] = None,
body: Optional[Dict] = None,
stream: bool = False,
files: Optional[Dict[str, BinaryIO]] = None,
) -> httpx.Response:
return self._http_client.execute_http_request(method=method, url=url, params=params, files=files, stream=stream)

def get_base_url(self) -> str:
return f"{self._api_host}/studio/{self._api_version}"
return self._http_client.execute_http_request(
method=method,
url=f"{self._base_url}{path}",
params=params or {},
files=files,
stream=stream,
body=body or {},
)
4 changes: 3 additions & 1 deletion ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ def __init__(
env_config: _AI21EnvConfig = AI21EnvConfig,
**kwargs,
):
base_url = api_host or env_config.api_host

self._http_client = AI21HTTPClient(
api_key=api_key or env_config.api_key,
api_host=api_host or env_config.api_host,
base_url=f"{base_url}/studio/v1",
api_version=env_config.api_version,
headers=headers,
timeout_sec=timeout_sec or env_config.timeout_sec,
Expand Down
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/chat/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ def create(
**kwargs,
)

url = f"{self._client.get_base_url()}/{self._module_name}"
return self._post(
url=url,
path=f"/{self._module_name}",
body=body,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
Expand Down
4 changes: 1 addition & 3 deletions ai21/clients/studio/resources/studio_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ def create(
question: str,
**kwargs,
) -> AnswerResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"

body = self._create_body(context=context, question=question, **kwargs)

return self._post(url=url, body=body, response_cls=AnswerResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=AnswerResponse)
4 changes: 2 additions & 2 deletions ai21/clients/studio/resources/studio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def create(
count_penalty=count_penalty,
**kwargs,
)
url = f"{self._client.get_base_url()}/{model}/{self._module_name}"
return self._post(url=url, body=body, response_cls=ChatResponse)

return self._post(path=f"/{model}/{self._module_name}", body=body, response_cls=ChatResponse)

@property
def completions(self) -> ChatCompletions:
Expand Down
8 changes: 4 additions & 4 deletions ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def create(
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
url = f"{self._client.get_base_url()}/{model}"
path = f"/{model}"

if custom_model:
url = f"{url}/{custom_model}"
path = f"{path}/{custom_model}"

url = f"{url}/{self._module_name}"
path = f"{path}/{self._module_name}"
body = self._create_body(
model=model,
prompt=prompt,
Expand All @@ -53,4 +53,4 @@ def create(
logit_bias=logit_bias,
**kwargs,
)
return self._post(url=url, body=body, response_cls=CompletionsResponse)
return self._post(path=path, body=body, response_cls=CompletionsResponse)
9 changes: 3 additions & 6 deletions ai21/clients/studio/resources/studio_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def create(
num_epochs: Optional[int] = None,
**kwargs,
) -> None:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(
dataset_id=dataset_id,
model_name=model_name,
Expand All @@ -25,12 +24,10 @@ def create(
num_epochs=num_epochs,
**kwargs,
)
self._post(url=url, body=body, response_cls=None)
self._post(path=f"/{self._module_name}", body=body, response_cls=None)

def list(self) -> List[CustomBaseModelResponse]:
url = f"{self._client.get_base_url()}/{self._module_name}"
return self._get(url=url, response_cls=List[CustomBaseModelResponse])
return self._get(path=f"/{self._module_name}", response_cls=List[CustomBaseModelResponse])

def get(self, resource_id: str) -> CustomBaseModelResponse:
url = f"{self._client.get_base_url()}/{self._module_name}/{resource_id}"
return self._get(url=url, response_cls=CustomBaseModelResponse)
return self._get(path=f"/{self._module_name}/{resource_id}", response_cls=CustomBaseModelResponse)
10 changes: 3 additions & 7 deletions ai21/clients/studio/resources/studio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ def create(
**kwargs,
)
return self._post(
url=self._base_url(),
path=f"/{self._module_name}",
body=body,
files=files,
)

def list(self) -> List[DatasetResponse]:
return self._get(url=self._base_url(), response_cls=List[DatasetResponse])
return self._get(path=f"/{self._module_name}", response_cls=List[DatasetResponse])

def get(self, dataset_pid: str) -> DatasetResponse:
url = f"{self._base_url()}/{dataset_pid}"
return self._get(url=url, response_cls=DatasetResponse)

def _base_url(self) -> str:
return f"{self._client.get_base_url()}/{self._module_name}"
return self._get(path=f"/{self._module_name}/{dataset_pid}", response_cls=DatasetResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

class StudioEmbed(StudioResource, Embed):
def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(texts=texts, type=type, **kwargs)

return self._post(url=url, body=body, response_cls=EmbedResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=EmbedResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_gec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@
class StudioGEC(StudioResource, GEC):
def create(self, text: str, **kwargs) -> GECResponse:
body = self._create_body(text=text, **kwargs)
url = f"{self._client.get_base_url()}/{self._module_name}"

return self._post(url=url, body=body, response_cls=GECResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=GECResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_improvements.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme
if len(types) == 0:
raise EmptyMandatoryListError("types")

url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(text=text, types=types, **kwargs)

return self._post(url=url, body=body, response_cls=ImprovementsResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=ImprovementsResponse)
22 changes: 7 additions & 15 deletions ai21/clients/studio/resources/studio_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,15 @@ def create(
public_url: Optional[str] | NotGiven = NOT_GIVEN,
**kwargs,
) -> str:
url = f"{self._client.get_base_url()}/{self._module_name}"
files = {"file": open(file_path, "rb")}
body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs})

raw_response = self._post(url=url, files=files, body=body, response_cls=dict)
raw_response = self._post(path=f"/{self._module_name}", files=files, body=body, response_cls=dict)

return raw_response["fileId"]

def get(self, file_id: str) -> FileResponse:
url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}"

return self._get(url=url, response_cls=FileResponse)
return self._get(path=f"/{self._module_name}/{file_id}", response_cls=FileResponse)

def list(
self,
Expand All @@ -51,10 +48,9 @@ def list(
limit: Optional[int] | NotGiven = NOT_GIVEN,
**kwargs,
) -> List[FileResponse]:
url = f"{self._client.get_base_url()}/{self._module_name}"
params = remove_not_given({"offset": offset, "limit": limit})

return self._get(url=url, params=params, response_cls=List[FileResponse])
return self._get(path=f"/{self._module_name}", params=params, response_cls=List[FileResponse])

def update(
self,
Expand All @@ -64,19 +60,17 @@ def update(
labels: Optional[List[str]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> None:
url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}"
body = remove_not_given(
{
"publicUrl": public_url,
"labels": labels,
**kwargs,
}
)
self._put(url=url, body=body)
self._put(path=f"/{self._module_name}/{file_id}", body=body)

def delete(self, file_id: str) -> None:
url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}"
self._delete(url=url)
self._delete(path=f"/{self._module_name}/{file_id}")


class LibrarySearch(StudioResource):
Expand All @@ -91,7 +85,6 @@ def create(
max_segments: Optional[int] | NotGiven = NOT_GIVEN,
**kwargs,
) -> LibrarySearchResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = remove_not_given(
{
"query": query,
Expand All @@ -102,7 +95,7 @@ def create(
}
)

return self._post(url=url, body=body, response_cls=LibrarySearchResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=LibrarySearchResponse)


class LibraryAnswer(StudioResource):
Expand All @@ -117,7 +110,6 @@ def create(
labels: Optional[List[str]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> LibraryAnswerResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = remove_not_given(
{
"question": question,
Expand All @@ -128,4 +120,4 @@ def create(
}
)

return self._post(url=url, body=body, response_cls=LibraryAnswerResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=LibraryAnswerResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_paraphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ def create(
end_index=end_index,
**kwargs,
)
url = f"{self._client.get_base_url()}/{self._module_name}"

return self._post(url=url, body=body, response_cls=ParaphraseResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=ParaphraseResponse)
22 changes: 12 additions & 10 deletions ai21/clients/studio/resources/studio_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,41 @@ def __init__(self, client: AI21HTTPClient):

def _post(
self,
url: str,
body: Dict[str, Any],
path: str,
body: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
response_cls: Optional[ResponseT] = None,
stream_cls: Optional[StreamT] = None,
stream: bool = False,
files: Optional[Dict[str, BinaryIO]] = None,
) -> ResponseT | StreamT:
response = self._client.execute_http_request(
method="POST",
url=url,
path=path,
stream=stream,
params=body or {},
body=body or {},
params=params or {},
files=files,
)

return self._cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls)

def _get(
self, url: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None
self, path: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None
) -> ResponseT | StreamT:
response = self._client.execute_http_request(method="GET", url=url, params=params or {})
response = self._client.execute_http_request(method="GET", path=path, params=params or {})
return self._cast_response(response=response, response_cls=response_cls)

def _put(
self, url: str, response_cls: Optional[ResponseT] = None, body: Dict[str, Any] = None
self, path: str, response_cls: Optional[ResponseT] = None, body: Dict[str, Any] = None
) -> ResponseT | StreamT:
response = self._client.execute_http_request(method="PUT", url=url, params=body or {})
response = self._client.execute_http_request(method="PUT", path=path, body=body or {})
return self._cast_response(response=response, response_cls=response_cls)

def _delete(self, url: str, response_cls: Optional[ResponseT] = None) -> ResponseT | StreamT:
def _delete(self, path: str, response_cls: Optional[ResponseT] = None) -> ResponseT | StreamT:
response = self._client.execute_http_request(
method="DELETE",
url=url,
path=path,
)
return self._cast_response(response=response, response_cls=response_cls)

Expand Down
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@
class StudioSegmentation(StudioResource, Segmentation):
def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse:
body = self._create_body(source=source, source_type=source_type.value, **kwargs)
url = f"{self._client.get_base_url()}/{self._module_name}"

return self._post(url=url, body=body, response_cls=SegmentationResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=SegmentationResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ def create(
summary_method=summary_method,
**kwargs,
)
url = f"{self._client.get_base_url()}/{self._module_name}"

return self._post(url=url, body=body, response_cls=SummarizeResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeResponse)
3 changes: 1 addition & 2 deletions ai21/clients/studio/resources/studio_summarize_by_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@ def create(
focus=focus,
**kwargs,
)
url = f"{self._client.get_base_url()}/{self._module_name}"

return self._post(url=url, body=body, response_cls=SummarizeBySegmentResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeBySegmentResponse)
9 changes: 6 additions & 3 deletions ai21/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ def execute_http_request(
method: str,
url: str,
params: Optional[Dict] = None,
body: Optional[Dict] = None,
stream: bool = False,
files: Optional[Dict[str, BinaryIO]] = None,
) -> httpx.Response:
try:
response = self._request(files=files, method=method, params=params, url=url, stream=stream)
response = self._request(files=files, method=method, params=params, url=url, stream=stream, body=body)
except RetryError as retry_error:
last_attempt = retry_error.last_attempt

Expand All @@ -107,6 +108,7 @@ def _request(
files: Optional[Dict[str, BinaryIO]],
method: str,
params: Optional[Dict],
body: Optional[Dict],
url: str,
stream: bool,
) -> httpx.Response:
Expand All @@ -130,14 +132,15 @@ def _request(
headers.pop(
"Content-Type"
) # multipart/form-data 'Content-Type' is being added when passing rb files and payload
data = params
data = body
else:
data = json.dumps(params).encode() if params else None
data = json.dumps(body).encode() if body else None

request = self._client.build_request(
method=method,
url=url,
headers=headers,
params=params,
data=data,
timeout=timeout,
files=files,
Expand Down
Loading