Skip to content

Commit 409e818

Browse files
authored
fix: Pass model id in Bedrock client init (#71)
* feat: Moved model_id in bedrock to init * test: Added tests to check model id * fix: bedrock example * refactor: rename test params
1 parent 8965828 commit 409e818

File tree

9 files changed

+65
-8
lines changed

9 files changed

+65
-8
lines changed

ai21/clients/bedrock/ai21_bedrock_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99

1010
class AI21BedrockClient:
1111
"""
12+
:param model_id: The model ID to use for the client.
1213
:param session: An optional boto3 session to use for the client.
1314
"""
1415

1516
def __init__(
1617
self,
18+
model_id: str,
1719
session: Optional[boto3.Session] = None,
1820
region: Optional[str] = None,
1921
env_config: _AI21EnvConfig = AI21EnvConfig,
2022
):
21-
2223
self._bedrock_session = BedrockSession(session=session, region=region or env_config.aws_region)
23-
self.completion = BedrockCompletion(self._bedrock_session)
24+
self.completion = BedrockCompletion(model_id=model_id, bedrock_session=self._bedrock_session)

ai21/clients/bedrock/resources/bedrock_completion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
class BedrockCompletion(BedrockResource):
88
def create(
99
self,
10-
model_id: str,
1110
prompt: str,
1211
*,
1312
max_tokens: Optional[int] = None,
@@ -42,6 +41,8 @@ def create(
4241
if count_penalty is not None:
4342
body["countPenalty"] = count_penalty.to_dict()
4443

44+
model_id = kwargs.get("model_id", self._model_id)
45+
4546
raw_response = self._invoke(model_id=model_id, body=body)
4647

4748
return CompletionsResponse.from_dict(raw_response)

ai21/clients/bedrock/resources/bedrock_resource.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77

88
class BedrockResource(ABC):
9-
def __init__(self, bedrock_session: BedrockSession):
9+
def __init__(self, model_id: str, bedrock_session: BedrockSession):
10+
self._model_id = model_id
1011
self._bedrock_session = bedrock_session
1112

1213
def _invoke(self, model_id: str, body: Dict[str, Any]) -> Dict[str, Any]:

examples/bedrock/completion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@
3838
"User: Hi, I have a question for you"
3939
)
4040

41-
response = AI21BedrockClient().completion.create(
41+
response = AI21BedrockClient(model_id=BedrockModelID.J2_MID_V1).completion.create(
4242
prompt=prompt,
4343
max_tokens=1000,
44-
model_id=BedrockModelID.J2_MID_V1,
4544
temperature=0,
4645
top_p=1,
4746
top_k_return=0,

tests/integration_tests/clients/bedrock/test_completion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,10 @@
5050
def test_completion__when_no_penalties__should_return_response(
5151
frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty]
5252
):
53-
client = AI21BedrockClient()
53+
client = AI21BedrockClient(model_id=BedrockModelID.J2_MID_V1)
5454
response = client.completion.create(
5555
prompt=_PROMPT,
5656
max_tokens=64,
57-
model_id=BedrockModelID.J2_MID_V1,
5857
temperature=0,
5958
top_p=1,
6059
top_k_return=0,

tests/unittests/clients/bedrock/__init__.py

Whitespace-only changes.

tests/unittests/clients/bedrock/resources/__init__.py

Whitespace-only changes.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
from pytest_mock import MockerFixture
5+
6+
from ai21.clients.bedrock.bedrock_session import BedrockSession
7+
8+
9+
@pytest.fixture
10+
def mock_bedrock_session(mocker: MockerFixture) -> Mock:
11+
return mocker.MagicMock(spec=BedrockSession)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import Optional
2+
3+
import pytest
4+
from pytest_mock import MockerFixture
5+
from unittest.mock import Mock, ANY
6+
7+
from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion
8+
from ai21.models import Prompt
9+
10+
_CTOR_PROVIDED_MODEL_ID = "constructor_provided_model_id"
11+
_INVOCATION_MODEL_ID = "invocation_model_id"
12+
13+
14+
@pytest.mark.parametrize(
15+
ids=[
16+
"when_model_id_not_passed_in_create__should_use_model_id_from_init",
17+
"when_model_id_passed_in_create__should_use_model_id_from_create",
18+
],
19+
argnames=["invocation_model_id", "expected_model_id"],
20+
argvalues=[
21+
(None, _CTOR_PROVIDED_MODEL_ID),
22+
(_INVOCATION_MODEL_ID, _INVOCATION_MODEL_ID),
23+
],
24+
)
25+
def test__when_model_id_create_and_init__should_use_one_from_create(
26+
invocation_model_id: Optional[str],
27+
expected_model_id: str,
28+
mock_bedrock_session: Mock,
29+
mocker: MockerFixture,
30+
):
31+
mock_bedrock_session.invoke_model.return_value = {
32+
"id": expected_model_id,
33+
"prompt": mocker.MagicMock(spec=Prompt),
34+
"completions": [],
35+
}
36+
37+
client = BedrockCompletion(model_id=_CTOR_PROVIDED_MODEL_ID, bedrock_session=mock_bedrock_session)
38+
39+
# We can not pass "None" explicitly to the create method, so we have to use the if else statement
40+
if invocation_model_id is None:
41+
client.create(prompt="test")
42+
else:
43+
client.create(model_id=invocation_model_id, prompt="test")
44+
45+
mock_bedrock_session.invoke_model.assert_called_once_with(model_id=expected_model_id, input_json=ANY)

0 commit comments

Comments
 (0)