diff --git a/README.md b/README.md
index f61214e4..8c843a81 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,6 @@
## Table of Contents
- [Examples](#examples-tldr) 🗂️
-- [Migration from v1.3.4 and below](#migration-from-v134-and-below)
- [AI21 Official Documentation](#Documentation)
- [Installation](#Installation) 💿
- [Usage - Chat Completions](#Usage)
@@ -29,7 +28,6 @@
- [Older Models Support Usage](#Older-Models-Support-Usage)
- [More Models](#More-Models)
- [Streaming](#Streaming)
- - [TSMs](#TSMs)
- [Token Counting](#Token-Counting)
- [Environment Variables](#Environment-Variables)
- [Error Handling](#Error-Handling)
@@ -48,101 +46,6 @@ If you want to quickly get a glance how to use the AI21 Python SDK and jump stra
Feel free to dive in, experiment, and adapt these examples to suit your needs. We believe they'll help you get up and running quickly.
-## Migration from v1.3.4 and below
-
-In `v2.0.0` we introduced a new SDK that is not backwards compatible with the previous version.
-This version allows for non-static instances of the client, defined parameters to each resource, modelized responses and
-more.
-
-
-Migration Examples
-
-### Instance creation (not available in v1.3.4 and below)
-
-```python
-from ai21 import AI21Client
-
-client = AI21Client(api_key='my_api_key')
-
-# or set api_key in environment variable - AI21_API_KEY and then
-client = AI21Client()
-```
-
-We No longer support static methods for each resource, instead we have a client instance that has a method for each
-allowing for more flexibility and better control.
-
-### Completion before/after
-
-```diff
-prompt = "some prompt"
-
-- import ai21
-- response = ai21.Completion.execute(model="j2-light", prompt=prompt, maxTokens=2)
-
-+ from ai21 import AI21Client
-+ client = ai21.AI21Client()
-+ response = client.completion.create(model="j2-light", prompt=prompt, max_tokens=2)
-```
-
-This applies to all resources. You would now need to create a client instance and use it to call the resource method.
-
-### Tokenization and Token counting before/after
-
-```diff
-- response = ai21.Tokenization.execute(text=prompt)
-- print(len(response)) # number of tokens
-
-+ from ai21 import AI21Client
-+ client = AI21Client()
-+ token_count = client.count_tokens(text=prompt)
-```
-
-### Key Access in Response Objects before/after
-
-It is no longer possible to access the response object as a dictionary. Instead, you can access the response object as an object with attributes.
-
-```diff
-- import ai21
-- response = ai21.Summarize.execute(source="some text", sourceType="TEXT")
-- response["summary"]
-
-+ from ai21 import AI21Client
-+ from ai21.models import DocumentType
-+ client = AI21Client()
-+ response = client.summarize.create(source="some text", source_type=DocumentType.TEXT)
-+ response.summary
-```
-
----
-
-### AWS Client Creations
-
-### Bedrock Client creation before/after
-
-```diff
-- import ai21
-- destination = ai21.BedrockDestination(model_id=ai21.BedrockModelID.J2_MID_V1)
-- response = ai21.Completion.execute(prompt=prompt, maxTokens=1000, destination=destination)
-
-+ from ai21 import AI21BedrockClient, BedrockModelID
-+ client = AI21BedrockClient()
-+ response = client.completion.create(prompt=prompt, max_tokens=1000, model_id=BedrockModelID.J2_MID_V1)
-```
-
-### SageMaker Client creation before/after
-
-```diff
-- import ai21
-- destination = ai21.SageMakerDestination("j2-mid-test-endpoint")
-- response = ai21.Completion.execute(prompt=prompt, maxTokens=1000, destination=destination)
-
-+ from ai21 import AI21SageMakerClient
-+ client = AI21SageMakerClient(endpoint_name="j2-mid-test-endpoint")
-+ response = client.completion.create(prompt=prompt, max_tokens=1000)
-```
-
-
-
## Documentation
---
@@ -220,20 +123,6 @@ asyncio.run(main())
A more detailed example can be found [here](examples/studio/chat/chat_completions.py).
-## Older Models Support Usage
-
-
-Examples
-
-### Supported Models:
-
-- j2-light
-- [j2-ultra](#Chat)
-- [j2-mid](#Completion)
-- [jamba-instruct](#Chat-Completion)
-
-you can read more about the models [here](https://docs.ai21.com/reference/j2-complete-api-ref#jurassic-2-models).
-
### Chat
```python
diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py
index d2cf5cfe..78b32e35 100644
--- a/ai21/clients/bedrock/ai21_bedrock_client.py
+++ b/ai21/clients/bedrock/ai21_bedrock_client.py
@@ -10,10 +10,6 @@
from ai21.clients.aws.aws_authorization import AWSAuthorization
from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
-from ai21.clients.studio.resources.studio_completion import (
- AsyncStudioCompletion,
- StudioCompletion,
-)
from ai21.errors import AccessDenied, APITimeoutError, ModelErrorException, NotFound
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
@@ -101,10 +97,6 @@ def __init__(
BaseBedrockClient.__init__(self, session=session, region=self._region)
self.chat = StudioChat(self)
- # Override the chat.create method to match the completions endpoint,
- # so it wouldn't get to the old J2 completion endpoint
- self.chat.create = self.chat.completions.create
- self.completion = StudioCompletion(self)
def _build_request(self, options: RequestOptions) -> httpx.Request:
options = self._prepare_options(options)
@@ -146,10 +138,7 @@ def __init__(
BaseBedrockClient.__init__(self, session=session, region=self._region)
self.chat = AsyncStudioChat(self)
- # Override the chat.create method to match the completions endpoint,
- # so it wouldn't get to the old J2 completion endpoint
self.chat.create = self.chat.completions.create
- self.completion = AsyncStudioCompletion(self)
def _build_request(self, options: RequestOptions) -> httpx.Request:
options = self._prepare_options(options)
diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py
deleted file mode 100644
index 2e39e432..00000000
--- a/ai21/clients/common/completion_base.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from typing import Dict, List
-
-from ai21.models import CompletionsResponse, Penalty
-from ai21.models._pydantic_compatibility import _to_dict
-from ai21.types import NOT_GIVEN, NotGiven
-from ai21.utils.typing import remove_not_given
-
-
-class Completion(ABC):
- _module_name = "complete"
-
- @abstractmethod
- def create(
- self,
- model: str,
- prompt: str,
- *,
- max_tokens: int | NotGiven = NOT_GIVEN,
- num_results: int | NotGiven = NOT_GIVEN,
- min_tokens: int | NotGiven = NOT_GIVEN,
- temperature: float | NOT_GIVEN = NOT_GIVEN,
- top_p: float | NotGiven = NOT_GIVEN,
- top_k_return: int | NotGiven = NOT_GIVEN,
- stop_sequences: List[str] | NotGiven = NOT_GIVEN,
- frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
- presence_penalty: Penalty | NotGiven = NOT_GIVEN,
- count_penalty: Penalty | NotGiven = NOT_GIVEN,
- epoch: int | NotGiven = NOT_GIVEN,
- logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
- **kwargs,
- ) -> CompletionsResponse:
- """
- :param model: model type you wish to interact with
- :param prompt: Text for model to complete
- :param max_tokens: The maximum number of tokens to generate per result
- :param num_results: Number of completions to sample and return.
- :param min_tokens: The minimum number of tokens to generate per result.
- :param temperature: A value controlling the "creativity" of the model's responses.
- :param top_p: A value controlling the diversity of the model's responses.
- :param top_k_return: The number of top-scoring tokens to consider for each generation step.
- :param stop_sequences: Stops decoding if any of the strings is generated
- :param frequency_penalty: A penalty applied to tokens that are frequently generated.
- :param presence_penalty: A penalty applied to tokens that are already present in the prompt.
- :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
- :param epoch:
- :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text
- representations of the tokens and the floats are the biases themselves. A positive bias increases generation
- probability for a given token and a negative bias decreases it.
- :param kwargs:
- :return:
- """
- pass
-
- def _create_body(
- self,
- model: str,
- prompt: str,
- max_tokens: int | NotGiven,
- num_results: int | NotGiven,
- min_tokens: int | NotGiven,
- temperature: float | NotGiven,
- top_p: float | NotGiven,
- top_k_return: int | NotGiven,
- stop_sequences: List[str] | NotGiven,
- frequency_penalty: Penalty | NotGiven,
- presence_penalty: Penalty | NotGiven,
- count_penalty: Penalty | NotGiven,
- epoch: int | NotGiven,
- logit_bias: Dict[str, float] | NotGiven,
- **kwargs,
- ):
- return remove_not_given(
- {
- "model": model,
- "prompt": prompt,
- "maxTokens": max_tokens,
- "numResults": num_results,
- "minTokens": min_tokens,
- "temperature": temperature,
- "topP": top_p,
- "topKReturn": top_k_return,
- "stopSequences": stop_sequences,
- "frequencyPenalty": (NOT_GIVEN if frequency_penalty is NOT_GIVEN else _to_dict(frequency_penalty)),
- "presencePenalty": (NOT_GIVEN if presence_penalty is NOT_GIVEN else _to_dict(presence_penalty)),
- "countPenalty": (NOT_GIVEN if count_penalty is NOT_GIVEN else _to_dict(count_penalty)),
- "epoch": epoch,
- "logitBias": logit_bias,
- **kwargs,
- }
- )
-
- def _get_completion_path(self, model: str):
- return f"/{model}/{self._module_name}"
diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py
deleted file mode 100644
index 82f3acf6..00000000
--- a/ai21/clients/studio/resources/studio_completion.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from __future__ import annotations
-
-from typing import Dict, List
-
-from ai21.clients.common.completion_base import Completion
-from ai21.clients.studio.resources.studio_resource import (
- AsyncStudioResource,
- StudioResource,
-)
-from ai21.models import CompletionsResponse, Penalty
-from ai21.types import NOT_GIVEN, NotGiven
-
-
-class StudioCompletion(StudioResource, Completion):
- def create(
- self,
- prompt: str,
- model: str,
- *,
- max_tokens: int | NotGiven = NOT_GIVEN,
- num_results: int | NotGiven = NOT_GIVEN,
- min_tokens: int | NotGiven = NOT_GIVEN,
- temperature: float | NotGiven = NOT_GIVEN,
- top_p: float | NotGiven = NOT_GIVEN,
- top_k_return: int | NotGiven = NOT_GIVEN,
- stop_sequences: List[str] | NotGiven = NOT_GIVEN,
- frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
- presence_penalty: Penalty | NotGiven = NOT_GIVEN,
- count_penalty: Penalty | NotGiven = NOT_GIVEN,
- epoch: int | NotGiven = NOT_GIVEN,
- logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
- **kwargs,
- ) -> CompletionsResponse:
- path = self._get_completion_path(model=model)
- body = self._create_body(
- model=model,
- prompt=prompt,
- max_tokens=max_tokens,
- num_results=num_results,
- min_tokens=min_tokens,
- temperature=temperature,
- top_p=top_p,
- top_k_return=top_k_return,
- stop_sequences=stop_sequences,
- frequency_penalty=frequency_penalty,
- presence_penalty=presence_penalty,
- count_penalty=count_penalty,
- epoch=epoch,
- logit_bias=logit_bias,
- **kwargs,
- )
- return self._post(path=path, body=body, response_cls=CompletionsResponse)
-
-
-class AsyncStudioCompletion(AsyncStudioResource, Completion):
- async def create(
- self,
- prompt: str,
- model: str,
- *,
- max_tokens: int | NotGiven = NOT_GIVEN,
- num_results: int | NotGiven = NOT_GIVEN,
- min_tokens: int | NotGiven = NOT_GIVEN,
- temperature: float | NotGiven = NOT_GIVEN,
- top_p: float | NotGiven = NOT_GIVEN,
- top_k_return: int | NotGiven = NOT_GIVEN,
- stop_sequences: List[str] | NotGiven = NOT_GIVEN,
- frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
- presence_penalty: Penalty | NotGiven = NOT_GIVEN,
- count_penalty: Penalty | NotGiven = NOT_GIVEN,
- epoch: int | NotGiven = NOT_GIVEN,
- logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
- **kwargs,
- ) -> CompletionsResponse:
- path = self._get_completion_path(model=model)
- body = self._create_body(
- model=model,
- prompt=prompt,
- max_tokens=max_tokens,
- num_results=num_results,
- min_tokens=min_tokens,
- temperature=temperature,
- top_p=top_p,
- top_k_return=top_k_return,
- stop_sequences=stop_sequences,
- frequency_penalty=frequency_penalty,
- presence_penalty=presence_penalty,
- count_penalty=count_penalty,
- epoch=epoch,
- logit_bias=logit_bias,
- **kwargs,
- )
-
- return await self._post(path=path, body=body, response_cls=CompletionsResponse)
diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py
index 7d45c031..c7f1106e 100644
--- a/ai21/models/__init__.py
+++ b/ai21/models/__init__.py
@@ -2,17 +2,14 @@
from ai21.models.chat_message import ChatMessage
from ai21.models.document_type import DocumentType
from ai21.models.penalty import Penalty
-from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason
-from ai21.models.responses.completion_response import (
- CompletionsResponse,
- Completion,
- CompletionFinishReason,
- CompletionData,
- Prompt,
+from ai21.models.responses.chat_response import ChatOutput, ChatResponse, FinishReason
+from ai21.models.responses.conversational_rag_response import (
+ ConversationalRagResponse,
+ ConversationalRagSource,
)
-from ai21.models.responses.conversational_rag_response import ConversationalRagResponse, ConversationalRagSource
from ai21.models.responses.file_response import FileResponse
+
__all__ = [
"ChatMessage",
"RoleType",
@@ -21,11 +18,6 @@
"ChatResponse",
"ChatOutput",
"FinishReason",
- "CompletionsResponse",
- "Completion",
- "CompletionFinishReason",
- "CompletionData",
- "Prompt",
"FileResponse",
"ConversationalRagResponse",
"ConversationalRagSource",
diff --git a/ai21/models/responses/completion_response.py b/ai21/models/responses/completion_response.py
deleted file mode 100644
index 15c2c80a..00000000
--- a/ai21/models/responses/completion_response.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from typing import List, Optional, Union, Any, Dict
-
-from pydantic import Field
-
-from ai21.models.ai21_base_model import AI21BaseModel
-
-
-class Prompt(AI21BaseModel):
- text: Optional[str]
- tokens: Optional[List[Dict[str, Any]]] = None
-
-
-class CompletionData(AI21BaseModel):
- text: Optional[str]
- tokens: Optional[List[Dict[str, Any]]] = None
-
-
-class CompletionFinishReason(AI21BaseModel):
- reason: Optional[str] = None
- length: Optional[int] = None
-
-
-class Completion(AI21BaseModel):
- data: CompletionData
- finish_reason: Optional[CompletionFinishReason] = Field(default=None, alias="finishReason")
-
-
-class CompletionsResponse(AI21BaseModel):
- id: Union[int, str]
- prompt: Prompt
- completions: List[Completion]
diff --git a/examples/bedrock/async_completion.py b/examples/bedrock/async_completion.py
deleted file mode 100644
index eb3fd4d3..00000000
--- a/examples/bedrock/async_completion.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import asyncio
-from ai21 import AsyncAI21BedrockClient, BedrockModelID
-
-# Bedrock is currently supported only in us-east-1 region.
-# Either set your profile's region to us-east-1 or uncomment next line
-# ai21.aws_region = 'us-east-1'
-# Or create a boto session and pass it:
-# import boto3
-# session = boto3.Session(region_name="us-east-1")
-
-prompt = (
- "The following is a conversation between a user of an eCommerce store and a user operation"
- " associate called Max. Max is very kind and keen to help."
- " The following are important points about the business policies:\n- "
- "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:"
- " Male.\n\nConversation:\nUser: Hi, had a question\nMax: "
- "Hi there, happy to help!\nUser: Is there no way to return a product?"
- " I got your blue T-Shirt size small but it doesn't fit.\n"
- "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n"
- "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n"
- "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation"
- " associate called Max. Max is very kind and keen to help. The following are important points about"
- " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n"
- 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" '
- "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me"
- " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n"
- "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between"
- " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help."
- " The following are important points about the business policies:\n- Delivery takes up to 5 days\n"
- "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it"
- " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working"
- " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n"
- "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n"
- "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n"
- "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an"
- " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following"
- " are important points about the business policies:\n- Delivery takes up to 5 days\n"
- "- There is no return option\n\nUser gender: Female.\n\nConversation:\n"
- "User: Hi, I have a question for you"
-)
-
-
-async def main():
- response = await AsyncAI21BedrockClient().completion.create(
- prompt=prompt,
- max_tokens=50,
- temperature=0,
- top_p=1,
- top_k_return=0,
- model=BedrockModelID.J2_ULTRA_V1,
- )
-
- print(response.completions[0].data.text)
- print(response.prompt.tokens[0]["textRange"]["start"])
-
-
-asyncio.run(main())
diff --git a/examples/bedrock/completion.py b/examples/bedrock/completion.py
deleted file mode 100644
index 2e60314d..00000000
--- a/examples/bedrock/completion.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from ai21 import AI21BedrockClient, BedrockModelID
-
-# Bedrock is currently supported only in us-east-1 region.
-# Either set your profile's region to us-east-1 or uncomment next line
-# ai21.aws_region = 'us-east-1'
-# Or create a boto session and pass it:
-# import boto3
-# session = boto3.Session(region_name="us-east-1")
-
-prompt = (
- "The following is a conversation between a user of an eCommerce store and a user operation"
- " associate called Max. Max is very kind and keen to help."
- " The following are important points about the business policies:\n- "
- "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:"
- " Male.\n\nConversation:\nUser: Hi, had a question\nMax: "
- "Hi there, happy to help!\nUser: Is there no way to return a product?"
- " I got your blue T-Shirt size small but it doesn't fit.\n"
- "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n"
- "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n"
- "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation"
- " associate called Max. Max is very kind and keen to help. The following are important points about"
- " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n"
- 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" '
- "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me"
- " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n"
- "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between"
- " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help."
- " The following are important points about the business policies:\n- Delivery takes up to 5 days\n"
- "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it"
- " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working"
- " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n"
- "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n"
- "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n"
- "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an"
- " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following"
- " are important points about the business policies:\n- Delivery takes up to 5 days\n"
- "- There is no return option\n\nUser gender: Female.\n\nConversation:\n"
- "User: Hi, I have a question for you"
-)
-
-response = AI21BedrockClient().completion.create(
- prompt=prompt, max_tokens=50, temperature=0, top_p=1, top_k_return=0, model=BedrockModelID.J2_ULTRA_V1
-)
-
-print(response.completions[0].data.text)
-print(response.prompt.tokens[0]["textRange"]["start"])
diff --git a/tests/integration_tests/clients/bedrock/test_completion.py b/tests/integration_tests/clients/bedrock/test_completion.py
deleted file mode 100644
index edf1bc00..00000000
--- a/tests/integration_tests/clients/bedrock/test_completion.py
+++ /dev/null
@@ -1,128 +0,0 @@
-from typing import Optional
-
-import pytest
-
-from ai21 import AI21BedrockClient, AsyncAI21BedrockClient
-from ai21.clients.bedrock.bedrock_model_id import BedrockModelID
-from ai21.models import Penalty
-from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests
-
-
-_PROMPT = "Once upon a time, in a land far, far away, there was a"
-
-
-@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.")
-@pytest.mark.parametrize(
- ids=[
- "when_no_penalties__should_return_response",
- "when_penalties__should_return_response",
- ],
- argnames=["frequency_penalty", "presence_penalty", "count_penalty"],
- argvalues=[
- (None, None, None),
- (
- Penalty(scale=0.5),
- Penalty(
- scale=0.5,
- apply_to_emojis=True,
- ),
- Penalty(
- scale=0.5,
- apply_to_emojis=True,
- apply_to_numbers=True,
- apply_to_stopwords=True,
- apply_to_punctuation=True,
- apply_to_whitespaces=True,
- ),
- ),
- ],
-)
-def test_completion_penalties__should_return_response(
- frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty]
-):
- client = AI21BedrockClient()
- completion_args = dict(
- prompt=_PROMPT,
- max_tokens=64,
- model=BedrockModelID.J2_ULTRA_V1,
- temperature=0,
- top_p=1,
- top_k_return=0,
- )
-
- for arg_name, penalty in [
- ("frequency_penalty", frequency_penalty),
- ("presence_penalty", presence_penalty),
- ("count_penalty", count_penalty),
- ]:
- if penalty:
- completion_args[arg_name] = penalty
-
- response = client.completion.create(**completion_args)
-
- assert response.prompt.text == _PROMPT
- assert len(response.completions) == 1
-
- # Check the results aren't all the same
- assert len([completion.data.text for completion in response.completions]) == 1
- for completion in response.completions:
- assert isinstance(completion.data.text, str)
-
-
-@pytest.mark.asyncio
-@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.")
-@pytest.mark.parametrize(
- ids=[
- "when_no_penalties__should_return_response",
- "when_penalties__should_return_response",
- ],
- argnames=["frequency_penalty", "presence_penalty", "count_penalty"],
- argvalues=[
- (None, None, None),
- (
- Penalty(scale=0.5),
- Penalty(
- scale=0.5,
- apply_to_emojis=True,
- ),
- Penalty(
- scale=0.5,
- apply_to_emojis=True,
- apply_to_numbers=True,
- apply_to_stopwords=True,
- apply_to_punctuation=True,
- apply_to_whitespaces=True,
- ),
- ),
- ],
-)
-async def test_async_completion_penalties__should_return_response(
- frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty]
-):
- client = AsyncAI21BedrockClient()
- completion_args = dict(
- prompt=_PROMPT,
- max_tokens=64,
- model=BedrockModelID.J2_ULTRA_V1,
- temperature=0,
- top_p=1,
- top_k_return=0,
- )
-
- for arg_name, penalty in [
- ("frequency_penalty", frequency_penalty),
- ("presence_penalty", presence_penalty),
- ("count_penalty", count_penalty),
- ]:
- if penalty:
- completion_args[arg_name] = penalty
-
- response = await client.completion.create(**completion_args)
-
- assert response.prompt.text == _PROMPT
- assert len(response.completions) == 1
-
- # Check the results aren't all the same
- assert len([completion.data.text for completion in response.completions]) == 1
- for completion in response.completions:
- assert isinstance(completion.data.text, str)
diff --git a/tests/integration_tests/clients/test_bedrock.py b/tests/integration_tests/clients/test_bedrock.py
index 79298ed4..3a00b4b8 100644
--- a/tests/integration_tests/clients/test_bedrock.py
+++ b/tests/integration_tests/clients/test_bedrock.py
@@ -3,12 +3,14 @@
"""
import subprocess
+
from pathlib import Path
import pytest
from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests
+
BEDROCK_DIR = "bedrock"
BEDROCK_PATH = Path(__file__).parent.parent.parent.parent / "examples" / BEDROCK_DIR
@@ -18,16 +20,12 @@
@pytest.mark.parametrize(
argnames=["test_file_name"],
argvalues=[
- ("completion.py",),
- ("async_completion.py",),
("chat/chat_completions.py",),
("chat/stream_chat_completions.py",),
("chat/async_chat_completions.py",),
("chat/async_stream_chat_completions.py",),
],
ids=[
- "when_completion__should_return_ok",
- "when_async_completion__should_return_ok",
"when_chat_completions__should_return_ok",
"when_stream_chat_completions__should_return_ok",
"when_async_chat_completions__should_return_ok",
diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py
index 1e8413b5..eff87271 100644
--- a/tests/unittests/clients/studio/resources/conftest.py
+++ b/tests/unittests/clients/studio/resources/conftest.py
@@ -1,25 +1,18 @@
import httpx
import pytest
+
from pytest_mock import MockerFixture
-from ai21.clients.studio.resources.chat import AsyncChatCompletions
-from ai21.clients.studio.resources.chat import ChatCompletions
-from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
-from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion
+from ai21.clients.studio.resources.chat import AsyncChatCompletions, ChatCompletions
+from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
-from ai21.models import (
- ChatMessage,
- RoleType,
- ChatResponse,
- CompletionsResponse,
-)
-from ai21.models._pydantic_compatibility import _to_dict, _from_dict
+from ai21.models import ChatMessage, ChatResponse, RoleType
+from ai21.models._pydantic_compatibility import _from_dict, _to_dict
from ai21.models.chat import (
- ChatMessage as ChatCompletionChatMessage,
ChatCompletionResponse,
+ ChatMessage as ChatCompletionChatMessage,
)
-from ai21.utils.typing import to_lower_camel_case
@pytest.fixture
@@ -134,33 +127,3 @@ def get_chat_completions(is_async: bool = False):
httpx.Response(status_code=200, json=json_response),
_from_dict(obj=ChatCompletionResponse, obj_dict=json_response),
)
-
-
-def get_studio_completion(is_async: bool = True, **kwargs):
- _DUMMY_MODEL = "dummy-completion-model"
- _DUMMY_PROMPT = "dummy-prompt"
- json_response = {
- "id": "some-id",
- "completions": [
- {
- "data": {"text": "dummy-completion", "tokens": []},
- "finishReason": {"reason": "dummy_reason", "length": 1},
- }
- ],
- "prompt": {"text": "dummy-prompt"},
- }
-
- resource = AsyncStudioCompletion if is_async else StudioCompletion
-
- return (
- resource,
- {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT, **kwargs},
- f"{_DUMMY_MODEL}/complete",
- {
- "model": _DUMMY_MODEL,
- "prompt": _DUMMY_PROMPT,
- **{to_lower_camel_case(k): v for k, v in kwargs.items()},
- },
- httpx.Response(status_code=200, json=json_response),
- _from_dict(obj=CompletionsResponse, obj_dict=json_response),
- )
diff --git a/tests/unittests/clients/studio/resources/test_async_studio_resource.py b/tests/unittests/clients/studio/resources/test_async_studio_resource.py
index 0897479c..d86a6598 100644
--- a/tests/unittests/clients/studio/resources/test_async_studio_resource.py
+++ b/tests/unittests/clients/studio/resources/test_async_studio_resource.py
@@ -1,4 +1,4 @@
-from typing import TypeVar, Callable
+from typing import Callable, TypeVar
import pytest
@@ -6,11 +6,11 @@
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.models.ai21_base_model import AI21BaseModel
from tests.unittests.clients.studio.resources.conftest import (
- get_studio_chat,
get_chat_completions,
- get_studio_completion,
+ get_studio_chat,
)
+
_BASE_URL = "https://test.api.ai21.com/studio/v1"
T = TypeVar("T", bound=AsyncStudioResource)
@@ -22,8 +22,6 @@ class TestAsyncStudioResources:
ids=[
"async_studio_chat",
"async_chat_completions",
- "async_studio_completion",
- "async_studio_completion_with_extra_args",
],
argnames=[
"studio_resource",
@@ -36,8 +34,6 @@ class TestAsyncStudioResources:
argvalues=[
(get_studio_chat(is_async=True)),
(get_chat_completions(is_async=True)),
- (get_studio_completion(is_async=True)),
- (get_studio_completion(is_async=True, temperature=0.5, max_tokens=50)),
],
)
async def test__create__should_return_response(
diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py
index 8f3cdd7c..0db9f8bb 100644
--- a/tests/unittests/clients/studio/resources/test_studio_resources.py
+++ b/tests/unittests/clients/studio/resources/test_studio_resources.py
@@ -1,15 +1,16 @@
-from typing import TypeVar, Callable
+from typing import Callable, TypeVar
import pytest
-from ai21.http_client.http_client import AI21HTTPClient
+
from ai21.clients.studio.resources.studio_resource import StudioResource
+from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.ai21_base_model import AI21BaseModel
from tests.unittests.clients.studio.resources.conftest import (
- get_studio_chat,
- get_studio_completion,
get_chat_completions,
+ get_studio_chat,
)
+
_BASE_URL = "https://test.api.ai21.com/studio/v1"
T = TypeVar("T", bound=StudioResource)
@@ -19,8 +20,6 @@ class TestStudioResources:
ids=[
"studio_chat",
"chat_completions",
- "studio_completion",
- "studio_completion_with_extra_args",
],
argnames=[
"studio_resource",
@@ -33,8 +32,6 @@ class TestStudioResources:
argvalues=[
(get_studio_chat()),
(get_chat_completions()),
- (get_studio_completion(is_async=False)),
- (get_studio_completion(is_async=False, temperature=0.5, max_tokens=50)),
],
)
def test__create__should_return_response(
diff --git a/tests/unittests/models/response_mocks.py b/tests/unittests/models/response_mocks.py
index ffb55ad2..5c4ad2fe 100644
--- a/tests/unittests/models/response_mocks.py
+++ b/tests/unittests/models/response_mocks.py
@@ -1,14 +1,4 @@
-from ai21.models import (
- ChatResponse,
- ChatOutput,
- RoleType,
- FinishReason,
- CompletionsResponse,
- Prompt,
- Completion,
- CompletionFinishReason,
- CompletionData,
-)
+from ai21.models import ChatOutput, ChatResponse, FinishReason, RoleType
from ai21.models.chat import ChatCompletionResponse, ChatCompletionResponseChoice
from ai21.models.chat.chat_message import AssistantMessage
from ai21.models.usage_info import UsageInfo
@@ -74,79 +64,3 @@ def get_chat_completions_response():
)
return chat_completions_response, expected_dict, ChatCompletionResponse
-
-
-def get_completions_response():
- expected_dict = {
- "id": "123-abc",
- "prompt": {
- "text": "life is like ",
- "tokens": [
- {
- "generatedToken": {
- "token": "▁life▁is",
- "logprob": -14.273218154907227,
- "raw_logprob": -14.273218154907227,
- },
- "topTokens": None,
- "textRange": {"start": 0, "end": 7},
- },
- ],
- },
- "completions": [
- {
- "data": {
- "text": "\nlife is like a journey, full of ups and downs, twists and turns. It is unpredictable "
- "and can be challenging, but it is also",
- "tokens": [
- {
- "generatedToken": {
- "token": "<|newline|>",
- "logprob": -0.006884856149554253,
- "raw_logprob": -0.12210073322057724,
- },
- "topTokens": None,
- "textRange": {"start": 0, "end": 1},
- },
- ],
- },
- "finishReason": {"reason": "length", "length": 16},
- }
- ],
- }
-
- prompt = Prompt(
- text="life is like ",
- tokens=[
- {
- "generatedToken": {
- "token": "▁life▁is",
- "logprob": -14.273218154907227,
- "raw_logprob": -14.273218154907227,
- },
- "topTokens": None,
- "textRange": {"start": 0, "end": 7},
- },
- ],
- )
- completion = Completion(
- data=CompletionData(
- text="\nlife is like a journey, full of ups and downs, twists and turns. It is unpredictable and can be "
- "challenging, but it is also",
- tokens=[
- {
- "generatedToken": {
- "token": "<|newline|>",
- "logprob": -0.006884856149554253,
- "raw_logprob": -0.12210073322057724,
- },
- "topTokens": None,
- "textRange": {"start": 0, "end": 1},
- }
- ],
- ),
- finish_reason=CompletionFinishReason(reason="length", length=16),
- )
- completion_response = CompletionsResponse(id="123-abc", prompt=prompt, completions=[completion])
-
- return completion_response, expected_dict, CompletionsResponse
diff --git a/tests/unittests/models/test_serialization.py b/tests/unittests/models/test_serialization.py
index 35473cd0..53f7fc16 100644
--- a/tests/unittests/models/test_serialization.py
+++ b/tests/unittests/models/test_serialization.py
@@ -1,14 +1,13 @@
-from typing import Dict, Any
+from typing import Any, Dict
import pytest
from ai21.models import Penalty
-from ai21.models._pydantic_compatibility import _to_dict, _from_dict
+from ai21.models._pydantic_compatibility import _from_dict, _to_dict
from ai21.models.ai21_base_model import IS_PYDANTIC_V2, AI21BaseModel
from tests.unittests.models.response_mocks import (
- get_chat_response,
get_chat_completions_response,
- get_completions_response,
+ get_chat_response,
)
@@ -43,7 +42,6 @@ def test_penalty__from_json__should_return_instance_with_given_values():
ids=[
"chat_response",
"chat_completions_response",
- "completion_response",
],
argnames=[
"response_obj",
@@ -53,7 +51,6 @@ def test_penalty__from_json__should_return_instance_with_given_values():
argvalues=[
(get_chat_response()),
(get_chat_completions_response()),
- (get_completions_response()),
],
)
def test_to_dict__should_serialize_to_dict__(
@@ -67,7 +64,6 @@ def test_to_dict__should_serialize_to_dict__(
ids=[
"chat_response",
"chat_completions_response",
- "completion_response",
],
argnames=[
"response_obj",
@@ -77,7 +73,6 @@ def test_to_dict__should_serialize_to_dict__(
argvalues=[
(get_chat_response()),
(get_chat_completions_response()),
- (get_completions_response()),
],
)
def test_from_dict__should_serialize_from_dict__(