Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
75e4d4e
fix: types
asafgardin Jan 24, 2024
7242325
test: Added some integration tests
asafgardin Jan 24, 2024
1525a0e
test: improvements
asafgardin Jan 25, 2024
cc7fbe1
test: test_paraphrase.py
asafgardin Jan 27, 2024
52f300e
fix: doc
asafgardin Jan 27, 2024
3d54cde
fix: removed unused comment
asafgardin Jan 27, 2024
0197c8c
test: test_summarize.py
asafgardin Jan 27, 2024
d9c148e
test: Added tests for test_summarize_by_segment.py
asafgardin Jan 28, 2024
e644bd2
test: test_segmentation.py
asafgardin Jan 28, 2024
1adb281
fix: file id in library response
asafgardin Jan 28, 2024
3ad97e4
fix: example for library
asafgardin Jan 28, 2024
0d03b19
Merge branch 'rc_2_0_0' into integration_tests
asafgardin Jan 29, 2024
0eacdbb
ci: Add rc branch prefix trigger for integration tests (#43)
asafgardin Jan 29, 2024
dc88a83
fix: types
asafgardin Jan 24, 2024
e119340
test: Added some integration tests
asafgardin Jan 24, 2024
e7b461a
test: improvements
asafgardin Jan 25, 2024
dc18348
test: test_paraphrase.py
asafgardin Jan 27, 2024
067497d
fix: doc
asafgardin Jan 27, 2024
a952acd
fix: removed unused comment
asafgardin Jan 27, 2024
5d70399
test: test_summarize.py
asafgardin Jan 27, 2024
3601984
test: Added tests for test_summarize_by_segment.py
asafgardin Jan 28, 2024
85669a2
test: test_segmentation.py
asafgardin Jan 28, 2024
dfc1f77
fix: file id in library response
asafgardin Jan 28, 2024
079b02e
fix: example for library
asafgardin Jan 28, 2024
ae67c6e
docs: docstrings
asafgardin Jan 30, 2024
42505af
fix: question
asafgardin Jan 30, 2024
224d8ce
fix: CR
asafgardin Jan 30, 2024
b2d7f44
Merge remote-tracking branch 'origin/integration_tests' into integrat…
asafgardin Jan 30, 2024
93caebd
test: Added tests to segment type in embed
asafgardin Jan 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- "rc_*"

env:
POETRY_VERSION: "1.7.1"
Expand Down
11 changes: 10 additions & 1 deletion ai21/clients/common/answer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ def create(
mode: Optional[Mode] = None,
**kwargs,
) -> AnswerResponse:
"""

:param context: A string containing the document context for which the question will be answered
:param question: A string containing the question to be answered based on the provided context.
:param answer_length: Approximate length of the answer in words.
:param mode:
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse:
Expand All @@ -26,7 +35,7 @@ def _create_body(
self,
context: str,
question: str,
answer_length: Optional[str],
answer_length: Optional[AnswerLength],
mode: Optional[str],
) -> Dict[str, Any]:
return {"context": context, "question": question, "answerLength": answer_length, "mode": mode}
19 changes: 19 additions & 0 deletions ai21/clients/common/chat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ def create(
count_penalty: Optional[Penalty] = None,
**kwargs,
) -> ChatResponse:
"""

:param model: model type you wish to interact with
:param messages: A sequence of messages ingested by the model, which then returns the assistant's response
:param system: Offers the model overarching guidance on its response approach, encapsulating context, tone,
guardrails, and more
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse:
Expand Down
18 changes: 18 additions & 0 deletions ai21/clients/common/completion_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ def create(
epoch: Optional[int] = None,
**kwargs,
) -> CompletionsResponse:
"""
:param model: model type you wish to interact with
:param prompt: Text for model to complete
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param custom_model:
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param epoch:
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse:
Expand Down
10 changes: 10 additions & 0 deletions ai21/clients/common/custom_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ def create(
num_epochs: Optional[int] = None,
**kwargs,
) -> None:
"""

:param dataset_id: The dataset you want to train your model on.
:param model_name: The name of your trained model
:param model_type: The type of model to train.
:param learning_rate: The learning rate used for training.
:param num_epochs: Number of epochs for training
:param kwargs:
:return:
"""
pass

@abstractmethod
Expand Down
11 changes: 11 additions & 0 deletions ai21/clients/common/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ def create(
split_ratio: Optional[float] = None,
**kwargs,
):
"""

:param file_path: Local path to dataset
:param dataset_name: Dataset name. Must be unique
:param selected_columns: Mapping of the columns in the dataset file to prompt and completion columns.
:param approve_whitespace_correction: Automatically correct examples that violate best practices
:param delete_long_rows: Allow removal of examples where prompt + completion lengths exceeds 2047 tokens
:param split_ratio:
:param kwargs:
:return:
"""
pass

@abstractmethod
Expand Down
8 changes: 8 additions & 0 deletions ai21/clients/common/embed_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ class Embed(ABC):

@abstractmethod
def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse:
"""

:param texts: A list of strings, each representing a document or segment of text to be embedded.
:param type: For retrieval/search use cases, indicates whether the texts that were
sent are segments or the query.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse:
Expand Down
6 changes: 6 additions & 0 deletions ai21/clients/common/gec_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ class GEC(ABC):

@abstractmethod
def create(self, text: str, **kwargs) -> GECResponse:
"""

:param text: The input text to be corrected.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> GECResponse:
Expand Down
7 changes: 7 additions & 0 deletions ai21/clients/common/improvements_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ class Improvements(ABC):

@abstractmethod
def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse:
"""

:param text: The input text to be improved.
:param types: Types of improvements to apply.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse:
Expand Down
10 changes: 10 additions & 0 deletions ai21/clients/common/paraphrase_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ def create(
end_index: Optional[int] = None,
**kwargs,
) -> ParaphraseResponse:
"""

:param text: The input text to be paraphrased.
:param style: Controls length and tone
:param start_index: Specifies the starting position of the paraphrasing process in the given text
:param end_index: specifies the position of the last character to be paraphrased, including the character
following it. If the parameter is not provided, the default value is set to the length of the given text.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse:
Expand Down
7 changes: 7 additions & 0 deletions ai21/clients/common/segmentation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ class Segmentation(ABC):

@abstractmethod
def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse:
"""

:param source: Raw input text, or URL of a web page.
:param source_type: The type of the source - either TEXT or URL.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse:
Expand Down
8 changes: 8 additions & 0 deletions ai21/clients/common/summarize_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def create(
summary_method: Optional[SummaryMethod] = None,
**kwargs,
) -> SummarizeResponse:
"""
:param source: The input text, or URL of a web page to be summarized.
:param source_type: Either TEXT or URL
:param focus: Summaries focused on a topic of your choice.
:param summary_method:
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse:
Expand Down
8 changes: 8 additions & 0 deletions ai21/clients/common/summarize_by_segment_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def create(
focus: Optional[str] = None,
**kwargs,
) -> SummarizeBySegmentResponse:
"""

:param source: The input text, or URL of a web page to be summarized.
:param source_type: Either TEXT or URL
:param focus: Summaries focused on a topic of your choice.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse:
Expand Down
3 changes: 2 additions & 1 deletion ai21/clients/studio/resources/studio_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from ai21.clients.common.embed_base import Embed
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models.embed_type import EmbedType
from ai21.models.responses.embed_response import EmbedResponse


class StudioEmbed(StudioResource, Embed):
def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse:
def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(texts=texts, type=type)
response = self._post(url=url, body=body)
Expand Down
1 change: 0 additions & 1 deletion ai21/clients/studio/resources/studio_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def create(
summary_method: Optional[SummaryMethod] = None,
**kwargs,
) -> SummarizeResponse:
# Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object.
body = self._create_body(
source=source,
source_type=source_type,
Expand Down
2 changes: 1 addition & 1 deletion ai21/models/responses/library_answer_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@dataclass
class SourceDocument(AI21BaseModelMixin):
field_id: str
file_id: str
name: str
highlights: List[str]
public_url: Optional[str] = None
Expand Down
8 changes: 7 additions & 1 deletion examples/studio/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ def validate_file_deleted():
file_path = os.getcwd()

path = os.path.join(file_path, file_name)
file_utils.create_file(file_path, file_name, content="test content" * 100)
_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the
Netherlands. From the 10th to the 16th century, Holland proper was a unified political
region within the Holy Roman Empire as a county ruled by the counts of Holland.
By the 17th century, the province of Holland had risen to become a maritime and economic power,
dominating the other provinces of the newly independent Dutch Republic."""
file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT)

file_id = client.library.files.create(
file_path=path,
Expand All @@ -31,6 +36,7 @@ def validate_file_deleted():
public_url="www.example.com",
)
print(file_id)

files = client.library.files.list()
print(files)
uploaded_file = client.library.files.get(file_id)
Expand Down
2 changes: 1 addition & 1 deletion examples/studio/library_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@


client = AI21Client()
response = client.library.answer.create(question="Where is Thailand?")
response = client.library.answer.create(question="Can you tell me something about Holland?")
print(response)
Empty file.
38 changes: 38 additions & 0 deletions tests/integration_tests/clients/studio/test_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from ai21 import AI21Client
from ai21.models import AnswerLength, Mode

_CONTEXT = (
"Holland is a geographical region[2] and former province on the western coast of"
" the Netherlands. From the "
"10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county "
"ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and "
"economic power, dominating the other provinces of the newly independent Dutch Republic."
)


@pytest.mark.parametrize(
ids=[
"when_answer_is_in_context",
"when_answer_not_in_context",
],
argnames=["question", "is_answer_in_context", "expected_answer_type"],
argvalues=[
("When did Holland become an economic power?", True, str),
("Is the ocean blue?", False, None),
],
)
def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type):
client = AI21Client()
response = client.answer.create(
context=_CONTEXT,
question=question,
answer_length=AnswerLength.LONG,
mode=Mode.FLEXIBLE,
)

assert response.answer_in_context == is_answer_in_context
if is_answer_in_context:
assert isinstance(response.answer, str)
else:
assert response.answer is None
Loading