Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
name: Integration Tests

on:
push:
branches:
- main
- "rc_*"
on: [push]

env:
POETRY_VERSION: "1.7.1"
Expand Down
13 changes: 10 additions & 3 deletions ai21/clients/bedrock/resources/bedrock_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ def create(
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(),
"presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(),
"countPenalty": None if count_penalty is None else count_penalty.to_dict(),
}

if frequency_penalty is not None:
body["frequencyPenalty"] = frequency_penalty.to_dict()

if presence_penalty is not None:
body["presencePenalty"] = presence_penalty.to_dict()

if count_penalty is not None:
body["countPenalty"] = count_penalty.to_dict()

raw_response = self._invoke(model_id=model_id, body=body)

return CompletionsResponse.from_dict(raw_response)
13 changes: 10 additions & 3 deletions ai21/clients/sagemaker/resources/sagemaker_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,17 @@ def create(
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(),
"presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(),
"countPenalty": None if count_penalty is None else count_penalty.to_dict(),
}

if frequency_penalty is not None:
body["frequencyPenalty"] = frequency_penalty.to_dict()

if presence_penalty is not None:
body["presencePenalty"] = presence_penalty.to_dict()

if count_penalty is not None:
body["countPenalty"] = count_penalty.to_dict()

raw_response = self._invoke(body)

return CompletionsResponse.from_dict(raw_response)
29 changes: 0 additions & 29 deletions examples/bedrock/completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from ai21 import AI21BedrockClient, BedrockModelID
from ai21.models import Penalty

# Bedrock is currently supported only in us-east-1 region.
# Either set your profile's region to us-east-1 or uncomment next line
Expand Down Expand Up @@ -46,34 +45,6 @@
temperature=0,
top_p=1,
top_k_return=0,
stop_sequences=["##"],
num_results=1,
custom_model=None,
epoch=1,
count_penalty=Penalty(
scale=0,
apply_to_emojis=False,
apply_to_numbers=False,
apply_to_stopwords=False,
apply_to_punctuation=False,
apply_to_whitespaces=False,
),
frequency_penalty=Penalty(
scale=0,
apply_to_emojis=False,
apply_to_numbers=False,
apply_to_stopwords=False,
apply_to_punctuation=False,
apply_to_whitespaces=False,
),
presence_penalty=Penalty(
scale=0,
apply_to_emojis=False,
apply_to_numbers=False,
apply_to_stopwords=False,
apply_to_punctuation=False,
apply_to_whitespaces=False,
),
)

print(response.completions[0].data.text)
Expand Down
3 changes: 0 additions & 3 deletions examples/studio/answer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from ai21 import AI21Client
from ai21.models import Mode, AnswerLength


client = AI21Client()
Expand All @@ -10,7 +9,5 @@
"ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and "
"economic power, dominating the other provinces of the newly independent Dutch Republic.",
question="When did Holland become an economic power?",
answer_length=AnswerLength.LONG,
mode=Mode.FLEXIBLE,
)
print(response)
Empty file.
71 changes: 71 additions & 0 deletions tests/integration_tests/clients/bedrock/test_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Optional

import pytest

from ai21 import AI21BedrockClient
from ai21.clients.bedrock.bedrock_model_id import BedrockModelID
from ai21.models import Penalty
from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests

_PROMPT = "Once upon a time, in a land far, far away, there was a"


@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.")
@pytest.mark.parametrize(
ids=[
"when_no_penalties__should_return_response",
"when_penalties__should_return_response",
],
argnames=["frequency_penalty", "presence_penalty", "count_penalty"],
argvalues=[
(None, None, None),
(
Penalty(
scale=0.5,
apply_to_emojis=True,
apply_to_numbers=True,
apply_to_stopwords=True,
apply_to_punctuation=True,
apply_to_whitespaces=True,
),
Penalty(
scale=0.5,
apply_to_emojis=True,
apply_to_numbers=True,
apply_to_stopwords=True,
apply_to_punctuation=True,
apply_to_whitespaces=True,
),
Penalty(
scale=0.5,
apply_to_emojis=True,
apply_to_numbers=True,
apply_to_stopwords=True,
apply_to_punctuation=True,
apply_to_whitespaces=True,
),
),
],
)
def test_completion__when_no_penalties__should_return_response(
frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty]
):
client = AI21BedrockClient()
response = client.completion.create(
prompt=_PROMPT,
max_tokens=64,
model_id=BedrockModelID.J2_MID_V1,
temperature=0,
top_p=1,
top_k_return=0,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
count_penalty=count_penalty,
)

assert response.prompt.text == _PROMPT
assert len(response.completions) == 1
# Check the results aren't all the same
assert len([completion.data.text for completion in response.completions]) == 1
for completion in response.completions:
assert isinstance(completion.data.text, str)