Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ai21/clients/common/completion_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def create(
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
epoch: int | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
"""
Expand All @@ -46,6 +47,9 @@ def create(
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param epoch:
:param logit_bias: A dictionary mapping strings to floats, where the strings are text representations of
tokens and the floats are the biases themselves. A positive bias increases the generation probability of a
given token and a negative bias decreases it.
:param kwargs:
:return:
"""
Expand All @@ -70,6 +74,7 @@ def _create_body(
presence_penalty: Penalty | NotGiven,
count_penalty: Penalty | NotGiven,
epoch: int | NotGiven,
logit_bias: Dict[str, float] | NotGiven,
):
return remove_not_given(
{
Expand All @@ -87,5 +92,6 @@ def _create_body(
"presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(),
"countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(),
"epoch": epoch,
"logitBias": logit_bias,
}
)
22 changes: 21 additions & 1 deletion ai21/clients/sagemaker/resources/sagemaker_completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict

from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.models import Penalty, CompletionsResponse
Expand All @@ -21,8 +21,27 @@ def create(
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
"""
:param prompt: Text for model to complete
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param logit_bias: A dictionary mapping strings to floats, where the strings are text representations of
tokens and the floats are the biases themselves. A positive bias increases the generation probability of a
given token and a negative bias decreases it.
:param kwargs:
:return:
"""
body = remove_not_given(
{
"prompt": prompt,
Expand All @@ -36,6 +55,7 @@ def create(
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
"logitBias": logit_bias,
}
)

Expand Down
4 changes: 3 additions & 1 deletion ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List
from typing import List, Dict

from ai21.clients.common.completion_base import Completion
from ai21.clients.studio.resources.studio_resource import StudioResource
Expand All @@ -26,6 +26,7 @@ def create(
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
epoch: int | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
url = f"{self._client.get_base_url()}/{model}"
Expand All @@ -49,5 +50,6 @@ def create(
presence_penalty=presence_penalty,
count_penalty=count_penalty,
epoch=epoch,
logit_bias=logit_bias,
)
return self._json_to_response(self._post(url=url, body=body))
1 change: 1 addition & 0 deletions examples/studio/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
num_results=1,
custom_model=None,
epoch=1,
logit_bias={"▁I'm▁sorry": -100.0},
count_penalty=Penalty(
scale=0,
apply_to_emojis=False,
Expand Down
23 changes: 23 additions & 0 deletions tests/integration_tests/clients/studio/test_completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from typing import Dict
from ai21 import AI21Client
from ai21.models import Penalty

Expand All @@ -25,6 +26,7 @@ def test_completion():
num_results=num_results,
custom_model=None,
epoch=1,
logit_bias={"▁a▁box▁of": -100.0},
count_penalty=Penalty(
scale=0,
apply_to_emojis=False,
Expand Down Expand Up @@ -110,3 +112,24 @@ def test_completion_when_finish_reason_defined__should_halt_on_expected_reason(
)

assert response.completions[0].finish_reason.reason == reason


@pytest.mark.parametrize(
ids=[
"no_logit_bias",
"logit_bias_negative",
],
argnames=["expected_result", "logit_bias"],
argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})],
)
def test_completion_logit_bias__should_impact_on_response(expected_result: str, logit_bias: Dict[str, float]):
client = AI21Client()
response = client.completion.create(
prompt="Life is like",
max_tokens=3,
model="j2-ultra",
temperature=0,
logit_bias=logit_bias,
)

assert response.completions[0].data.text.strip() == expected_result.strip()