Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions ai21/clients/bedrock/resources/bedrock_completion.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,45 @@
from typing import Optional, List
from __future__ import annotations

from typing import List

from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource
from ai21.models import Penalty, CompletionsResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given


class BedrockCompletion(BedrockResource):
def create(
self,
prompt: str,
*,
max_tokens: Optional[int] = None,
num_results: Optional[int] = 1,
min_tokens: Optional[int] = 0,
temperature: Optional[float] = 0.7,
top_p: Optional[int] = 1,
top_k_return: Optional[int] = 0,
stop_sequences: Optional[List[str]] = None,
frequency_penalty: Optional[Penalty] = None,
presence_penalty: Optional[Penalty] = None,
count_penalty: Optional[Penalty] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
top_k_return: int | NotGiven = NOT_GIVEN,
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
body = {
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
}

if frequency_penalty is not None:
body["frequencyPenalty"] = frequency_penalty.to_dict()

if presence_penalty is not None:
body["presencePenalty"] = presence_penalty.to_dict()

if count_penalty is not None:
body["countPenalty"] = count_penalty.to_dict()
body = remove_not_given(
{
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
}
)

model_id = kwargs.get("model_id", self._model_id)

Expand Down
88 changes: 47 additions & 41 deletions ai21/clients/common/completion_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, List, Dict, Any
from typing import List, Dict, Any

from ai21.models import Penalty, CompletionsResponse
from ai21.types import NOT_GIVEN, NotGiven
from ai21.utils.typing import remove_not_given


class Completion(ABC):
Expand All @@ -13,18 +17,18 @@ def create(
model: str,
prompt: str,
*,
max_tokens: int = 64,
num_results: int = 1,
min_tokens=0,
temperature=0.7,
top_p=1,
top_k_return=0,
custom_model: Optional[str] = None,
stop_sequences: Optional[List[str]] = (),
frequency_penalty: Optional[Penalty] = None,
presence_penalty: Optional[Penalty] = None,
count_penalty: Optional[Penalty] = None,
epoch: Optional[int] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NOT_GIVEN = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
top_k_return: int | NotGiven = NOT_GIVEN,
custom_model: str | NotGiven = NOT_GIVEN,
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
epoch: int | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
"""
Expand Down Expand Up @@ -54,32 +58,34 @@ def _create_body(
self,
model: str,
prompt: str,
max_tokens: Optional[int],
num_results: Optional[int],
min_tokens: Optional[int],
temperature: Optional[float],
top_p: Optional[int],
top_k_return: Optional[int],
custom_model: Optional[str],
stop_sequences: Optional[List[str]],
frequency_penalty: Optional[Penalty],
presence_penalty: Optional[Penalty],
count_penalty: Optional[Penalty],
epoch: Optional[int],
max_tokens: int | NotGiven,
num_results: int | NotGiven,
min_tokens: int | NotGiven,
temperature: float | NotGiven,
top_p: float | NotGiven,
top_k_return: int | NotGiven,
custom_model: str | NotGiven,
stop_sequences: List[str] | NotGiven,
frequency_penalty: Penalty | NotGiven,
presence_penalty: Penalty | NotGiven,
count_penalty: Penalty | NotGiven,
epoch: int | NotGiven,
):
return {
"model": model,
"customModel": custom_model,
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(),
"presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(),
"countPenalty": None if count_penalty is None else count_penalty.to_dict(),
"epoch": epoch,
}
return remove_not_given(
{
"model": model,
"customModel": custom_model,
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences,
"frequencyPenalty": NOT_GIVEN if frequency_penalty is NOT_GIVEN else frequency_penalty.to_dict(),
"presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(),
"countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(),
"epoch": epoch,
}
)
58 changes: 28 additions & 30 deletions ai21/clients/sagemaker/resources/sagemaker_completion.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,43 @@
from typing import Optional, List
from typing import List

from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.models import Penalty, CompletionsResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given


class SageMakerCompletion(SageMakerResource):
def create(
self,
prompt: str,
*,
max_tokens: Optional[int] = None,
num_results: Optional[int] = 1,
min_tokens: Optional[int] = 0,
temperature: Optional[float] = 0.7,
top_p: Optional[int] = 1,
top_k_return: Optional[int] = 0,
stop_sequences: Optional[List[str]] = None,
frequency_penalty: Optional[Penalty] = None,
presence_penalty: Optional[Penalty] = None,
count_penalty: Optional[Penalty] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
top_k_return: int | NotGiven = NOT_GIVEN,
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
body = {
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
}

if frequency_penalty is not None:
body["frequencyPenalty"] = frequency_penalty.to_dict()

if presence_penalty is not None:
body["presencePenalty"] = presence_penalty.to_dict()

if count_penalty is not None:
body["countPenalty"] = count_penalty.to_dict()
body = remove_not_given(
{
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
}
)

raw_response = self._invoke(body)

Expand Down
31 changes: 17 additions & 14 deletions ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Optional, List
from __future__ import annotations

from typing import List

from ai21.clients.common.completion_base import Completion
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models import Penalty, CompletionsResponse
from ai21.types import NOT_GIVEN, NotGiven


class StudioCompletion(StudioResource, Completion):
Expand All @@ -11,23 +14,23 @@ def create(
model: str,
prompt: str,
*,
max_tokens: Optional[int] = None,
num_results: Optional[int] = 1,
min_tokens: Optional[int] = 0,
temperature: Optional[float] = 0.7,
top_p: Optional[float] = 1,
top_k_return: Optional[int] = 0,
custom_model: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
frequency_penalty: Optional[Penalty] = None,
presence_penalty: Optional[Penalty] = None,
count_penalty: Optional[Penalty] = None,
epoch: Optional[int] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
top_k_return: int | NotGiven = NOT_GIVEN,
custom_model: str | NotGiven = NOT_GIVEN,
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
epoch: int | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
url = f"{self._client.get_base_url()}/{model}"

if custom_model is not None:
if custom_model:
url = f"{url}/{custom_model}"

url = f"{url}/{self._module_name}"
Expand Down
3 changes: 3 additions & 0 deletions ai21/models/ai21_base_model_mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from dataclasses_json import LetterCase, DataClassJsonMixin

from ai21.utils.typing import is_not_given


class AI21BaseModelMixin(DataClassJsonMixin):
dataclass_json_config = {
"letter_case": LetterCase.CAMEL,
"exclude": is_not_given,
}
14 changes: 8 additions & 6 deletions ai21/models/penalty.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from ai21.types import NOT_GIVEN, NotGiven
from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin


@dataclass
class Penalty(AI21BaseModelMixin):
scale: float
apply_to_whitespaces: Optional[bool] = None
apply_to_punctuation: Optional[bool] = None
apply_to_numbers: Optional[bool] = None
apply_to_stopwords: Optional[bool] = None
apply_to_emojis: Optional[bool] = None
apply_to_whitespaces: bool | NotGiven = NOT_GIVEN
apply_to_punctuation: bool | NotGiven = NOT_GIVEN
apply_to_numbers: bool | NotGiven = NOT_GIVEN
apply_to_stopwords: bool | NotGiven = NOT_GIVEN
apply_to_emojis: bool | NotGiven = NOT_GIVEN
30 changes: 30 additions & 0 deletions ai21/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing_extensions import Literal


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).

For example:

```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
...


get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""

def __bool__(self) -> Literal[False]:
return False

def __repr__(self) -> str:
return "NOT_GIVEN"


NOT_GIVEN = NotGiven()
Empty file added ai21/utils/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions ai21/utils/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Dict

from ai21.types import NotGiven


def is_not_given(value: Any) -> bool:
return isinstance(value, NotGiven)


def remove_not_given(body: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in body.items() if not is_not_given(v)}


def to_camel_case(snake_str: str) -> str:
return "".join(x.capitalize() for x in snake_str.lower().split("_"))


def to_lower_camel_case(snake_str: str) -> str:
# We capitalize the first letter of each component except the first one
# with the 'capitalize' method and join them together.
camel_string = to_camel_case(snake_str)
return snake_str[0].lower() + camel_string[1:]
Loading