diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 83291bbc..ee23e90d 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -4,6 +4,7 @@ on: push: branches: - main + - "rc_*" env: POETRY_VERSION: "1.7.1" diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index a1543646..1821db9a 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -17,6 +17,15 @@ def create( mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: + """ + + :param context: A string containing the document context for which the question will be answered + :param question: A string containing the question to be answered based on the provided context. + :param answer_length: Approximate length of the answer in words. + :param mode: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse: @@ -26,7 +35,7 @@ def _create_body( self, context: str, question: str, - answer_length: Optional[str], + answer_length: Optional[AnswerLength], mode: Optional[str], ) -> Dict[str, Any]: return {"context": context, "question": question, "answerLength": answer_length, "mode": mode} diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index e73dba3c..03037491 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -28,6 +28,25 @@ def create( count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: + """ + + :param model: model type you wish to interact with + :param messages: A sequence of messages ingested by the model, which then returns the assistant's response + :param system: Offers the model overarching guidance on its response approach, encapsulating context, tone, + guardrails, and more + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index a2bc8c3d..06fef7d8 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -28,6 +28,24 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: + """ + :param model: model type you wish to interact with + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param custom_model: + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param epoch: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse: diff --git a/ai21/clients/common/custom_model_base.py b/ai21/clients/common/custom_model_base.py index 7d3b55ae..303b1adf 100644 --- a/ai21/clients/common/custom_model_base.py +++ b/ai21/clients/common/custom_model_base.py @@ -18,6 +18,16 @@ def create( num_epochs: Optional[int] = None, **kwargs, ) -> None: + """ + + :param dataset_id: The dataset you want to train your model on. + :param model_name: The name of your trained model + :param model_type: The type of model to train. + :param learning_rate: The learning rate used for training. + :param num_epochs: Number of epochs for training + :param kwargs: + :return: + """ pass @abstractmethod diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py index 9fa57f85..732dee39 100644 --- a/ai21/clients/common/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -19,6 +19,17 @@ def create( split_ratio: Optional[float] = None, **kwargs, ): + """ + + :param file_path: Local path to dataset + :param dataset_name: Dataset name. Must be unique + :param selected_columns: Mapping of the columns in the dataset file to prompt and completion columns. + :param approve_whitespace_correction: Automatically correct examples that violate best practices + :param delete_long_rows: Allow removal of examples where prompt + completion lengths exceeds 2047 tokens + :param split_ratio: + :param kwargs: + :return: + """ pass @abstractmethod diff --git a/ai21/clients/common/embed_base.py b/ai21/clients/common/embed_base.py index baadd4ec..aaf9363e 100644 --- a/ai21/clients/common/embed_base.py +++ b/ai21/clients/common/embed_base.py @@ -10,6 +10,14 @@ class Embed(ABC): @abstractmethod def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: + """ + + :param texts: A list of strings, each representing a document or segment of text to be embedded. + :param type: For retrieval/search use cases, indicates whether the texts that were + sent are segments or the query. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse: diff --git a/ai21/clients/common/gec_base.py b/ai21/clients/common/gec_base.py index 8de743e2..091e6427 100644 --- a/ai21/clients/common/gec_base.py +++ b/ai21/clients/common/gec_base.py @@ -9,6 +9,12 @@ class GEC(ABC): @abstractmethod def create(self, text: str, **kwargs) -> GECResponse: + """ + + :param text: The input text to be corrected. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> GECResponse: diff --git a/ai21/clients/common/improvements_base.py b/ai21/clients/common/improvements_base.py index df912e1d..df13fe58 100644 --- a/ai21/clients/common/improvements_base.py +++ b/ai21/clients/common/improvements_base.py @@ -10,6 +10,13 @@ class Improvements(ABC): @abstractmethod def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: + """ + + :param text: The input text to be improved. + :param types: Types of improvements to apply. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse: diff --git a/ai21/clients/common/paraphrase_base.py b/ai21/clients/common/paraphrase_base.py index 917cdd75..3c01bbb7 100644 --- a/ai21/clients/common/paraphrase_base.py +++ b/ai21/clients/common/paraphrase_base.py @@ -18,6 +18,16 @@ def create( end_index: Optional[int] = None, **kwargs, ) -> ParaphraseResponse: + """ + + :param text: The input text to be paraphrased. + :param style: Controls length and tone + :param start_index: Specifies the starting position of the paraphrasing process in the given text + :param end_index: specifies the position of the last character to be paraphrased, including the character + following it. If the parameter is not provided, the default value is set to the length of the given text. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse: diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py index c4f658a9..97c74104 100644 --- a/ai21/clients/common/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -10,6 +10,13 @@ class Segmentation(ABC): @abstractmethod def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: + """ + + :param source: Raw input text, or URL of a web page. + :param source_type: The type of the source - either TEXT or URL. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py index 85cec2f5..2a70dfcd 100644 --- a/ai21/clients/common/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -16,6 +16,14 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: + """ + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Summaries focused on a topic of your choice. + :param summary_method: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse: diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py index 40a4abfa..236337de 100644 --- a/ai21/clients/common/summarize_by_segment_base.py +++ b/ai21/clients/common/summarize_by_segment_base.py @@ -19,6 +19,14 @@ def create( focus: Optional[str] = None, **kwargs, ) -> SummarizeBySegmentResponse: + """ + + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Summaries focused on a topic of your choice. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse: diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 7495ea67..7e7c8fad 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -2,11 +2,12 @@ from ai21.clients.common.embed_base import Embed from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.embed_type import EmbedType from ai21.models.responses.embed_response import EmbedResponse class StudioEmbed(StudioResource, Embed): - def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse: + def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type) response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index b2b5f860..4180ff52 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -18,7 +18,6 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: - # Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object. body = self._create_body( source=source, source_type=source_type, diff --git a/ai21/models/responses/library_answer_response.py b/ai21/models/responses/library_answer_response.py index 36341eda..28fab165 100644 --- a/ai21/models/responses/library_answer_response.py +++ b/ai21/models/responses/library_answer_response.py @@ -6,7 +6,7 @@ @dataclass class SourceDocument(AI21BaseModelMixin): - field_id: str + file_id: str name: str highlights: List[str] public_url: Optional[str] = None diff --git a/examples/studio/library.py b/examples/studio/library.py index d693d697..ca8e4840 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -22,7 +22,12 @@ def validate_file_deleted(): file_path = os.getcwd() path = os.path.join(file_path, file_name) -file_utils.create_file(file_path, file_name, content="test content" * 100) +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" +file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT) file_id = client.library.files.create( file_path=path, @@ -31,6 +36,7 @@ def validate_file_deleted(): public_url="www.example.com", ) print(file_id) + files = client.library.files.list() print(files) uploaded_file = client.library.files.get(file_id) diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py index 54b2bb1d..20d46402 100644 --- a/examples/studio/library_answer.py +++ b/examples/studio/library_answer.py @@ -2,5 +2,5 @@ client = AI21Client() -response = client.library.answer.create(question="Where is Thailand?") +response = client.library.answer.create(question="Can you tell me something about Holland?") print(response) diff --git a/tests/integration_tests/clients/studio/__init__.py b/tests/integration_tests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py new file mode 100644 index 00000000..51dd4fa2 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -0,0 +1,38 @@ +import pytest +from ai21 import AI21Client +from ai21.models import AnswerLength, Mode + +_CONTEXT = ( + "Holland is a geographical region[2] and former province on the western coast of" + " the Netherlands. From the " + "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " + "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " + "economic power, dominating the other provinces of the newly independent Dutch Republic." +) + + +@pytest.mark.parametrize( + ids=[ + "when_answer_is_in_context", + "when_answer_not_in_context", + ], + argnames=["question", "is_answer_in_context", "expected_answer_type"], + argvalues=[ + ("When did Holland become an economic power?", True, str), + ("Is the ocean blue?", False, None), + ], +) +def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type): + client = AI21Client() + response = client.answer.create( + context=_CONTEXT, + question=question, + answer_length=AnswerLength.LONG, + mode=Mode.FLEXIBLE, + ) + + assert response.answer_in_context == is_answer_in_context + if is_answer_in_context: + assert isinstance(response.answer, str) + else: + assert response.answer is None diff --git a/tests/integration_tests/clients/studio/test_chat.py b/tests/integration_tests/clients/studio/test_chat.py new file mode 100644 index 00000000..70d26761 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_chat.py @@ -0,0 +1,94 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ChatMessage, RoleType, Penalty, FinishReason + +_MODEL = "j2-ultra" +_MESSAGES = [ + ChatMessage( + text="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", + role=RoleType.USER, + ), +] +_SYSTEM = "You are a teacher in a public school" + + +def test_chat(): + num_results = 5 + messages = _MESSAGES + + client = AI21Client() + response = client.chat.create( + system=_SYSTEM, + messages=messages, + num_results=num_results, + max_tokens=64, + temperature=0.7, + min_tokens=1, + stop_sequences=["\n"], + top_p=0.3, + top_k_return=0, + model=_MODEL, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.outputs[0].role == RoleType.ASSISTANT + assert isinstance(response.outputs[0].text, str) + assert response.outputs[0].finish_reason == FinishReason(reason="endoftext") + + assert len(response.outputs) == num_results + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (2, "##", "length"), + (100, "##", "endoftext"), + (20, ".", "stop"), + ], +) +def test_chat_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.chat.create( + messages=_MESSAGES, + system=_SYSTEM, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + ) + + assert response.outputs[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py new file mode 100644 index 00000000..9938fa2a --- /dev/null +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -0,0 +1,112 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import Penalty + +_PROMPT = """ +User: Haven't received a confirmation email for my order #12345. +Assistant: I'm sorry to hear that. I'll look into it right away. +User: Can you please let me know when I can expect to receive it? +""" + + +def test_completion(): + num_results = 3 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=0.7, + top_p=0.2, + top_k_return=0.2, + stop_sequences=["##"], + num_results=num_results, + custom_model=None, + epoch=1, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Check the results aren't all the same + assert len([completion.data.text for completion in response.completions]) == num_results + for completion in response.completions: + assert isinstance(completion.data.text, str) + + +def test_completion_when_temperature_1_and_top_p_is_0__should_return_same_response(): + num_results = 5 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=1, + top_p=0, + top_k_return=0, + num_results=num_results, + epoch=1, + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Verify all results are the same + assert len(set([completion.data.text for completion in response.completions])) == 1 + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (10, "##", "length"), + (100, "##", "endoftext"), + (50, "\n", "stop"), + ], +) +def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + epoch=1, + ) + + assert response.completions[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_embed.py b/tests/integration_tests/clients/studio/test_embed.py new file mode 100644 index 00000000..8fc77dc5 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_embed.py @@ -0,0 +1,37 @@ +from typing import List + +import pytest +from ai21 import AI21Client +from ai21.models import EmbedType + +_TEXT_0 = "Holland is a geographical region and former province on the western coast of the Netherlands." +_TEXT_1 = "Germany is a country in Central Europe. It is the second-most populous country in Europe after Russia" + +_SEGMENT_0 = "The sun sets behind the mountains," +_SEGMENT_1 = "casting a warm glow over" +_SEGMENT_2 = "the city of Amsterdam." + + +@pytest.mark.parametrize( + ids=[ + "when_single_text_and_query__should_return_single_embedding", + "when_multiple_text_and_query__should_return_multiple_embeddings", + "when_single_text_and_segment__should_return_single_embedding", + "when_multiple_text_and_segment__should_return_multiple_embeddings", + ], + argnames=["texts", "type"], + argvalues=[ + ([_TEXT_0], EmbedType.QUERY), + ([_TEXT_0, _TEXT_1], EmbedType.QUERY), + ([_SEGMENT_0], EmbedType.SEGMENT), + ([_SEGMENT_0, _SEGMENT_1, _SEGMENT_2], EmbedType.SEGMENT), + ], +) +def test_embed(texts: List[str], type: EmbedType): + client = AI21Client() + response = client.embed.create( + texts=texts, + type=type, + ) + + assert len(response.results) == len(texts) diff --git a/tests/integration_tests/clients/studio/test_gec.py b/tests/integration_tests/clients/studio/test_gec.py new file mode 100644 index 00000000..51418cf1 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_gec.py @@ -0,0 +1,31 @@ +import pytest +from ai21 import AI21Client +from ai21.models import CorrectionType + + +@pytest.mark.parametrize( + ids=[ + "should_fix_spelling", + "should_fix_grammar", + "should_fix_missing_word", + "should_fix_punctuation", + "should_fix_wrong_word", + ], + argnames=["text", "correction_type", "expected_suggestion"], + argvalues=[ + ("jazzz is music", CorrectionType.SPELLING, "Jazz"), + ("You am nice", CorrectionType.GRAMMAR, "are"), + ( + "He stared out the window, lost in thought, as the raindrops against the glass.", + CorrectionType.MISSING_WORD, + "raindrops fell against", + ), + ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), + ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), + ], +) +def test_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): + client = AI21Client() + response = client.gec.create(text=text) + assert response.corrections[0].suggestion == expected_suggestion + assert response.corrections[0].correction_type == correction_type diff --git a/tests/integration_tests/clients/studio/test_improvements.py b/tests/integration_tests/clients/studio/test_improvements.py new file mode 100644 index 00000000..3488f781 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_improvements.py @@ -0,0 +1,13 @@ +from ai21 import AI21Client +from ai21.models import ImprovementType + + +def test_improvements(): + client = AI21Client() + response = client.improvements.create( + text="Affiliated with the profession of project management," + " I have ameliorated myself with a different set of hard skills as well as soft skills.", + types=[ImprovementType.FLUENCY], + ) + + assert len(response.improvements) > 0 diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py new file mode 100644 index 00000000..a7ba93aa --- /dev/null +++ b/tests/integration_tests/clients/studio/test_paraphrase.py @@ -0,0 +1,51 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ParaphraseStyleType + + +def test_paraphrase(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.FORMAL, + start_index=0, + end_index=20, + ) + for suggestion in response.suggestions: + print(suggestion.text) + assert len(response.suggestions) > 0 + + +def test_paraphrase__when_start_and_end_index_is_small__should_not_return_suggestions(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.GENERAL, + start_index=0, + end_index=5, + ) + assert len(response.suggestions) == 0 + + +@pytest.mark.parametrize( + ids=["when_general", "when_casual", "when_long", "when_short", "when_formal"], + argnames=["style"], + argvalues=[ + (ParaphraseStyleType.GENERAL,), + (ParaphraseStyleType.CASUAL,), + (ParaphraseStyleType.LONG,), + (ParaphraseStyleType.SHORT,), + (ParaphraseStyleType.FORMAL,), + ], +) +def test_paraphrase_styles(style: ParaphraseStyleType): + client = AI21Client() + response = client.paraphrase.create( + text="Today is a beautiful day.", + style=style, + start_index=0, + end_index=25, + ) + + assert len(response.suggestions) > 0 diff --git a/tests/integration_tests/clients/studio/test_segmentation.py b/tests/integration_tests/clients/studio/test_segmentation.py new file mode 100644 index 00000000..ede8707c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_segmentation.py @@ -0,0 +1,55 @@ +import pytest +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_segments", + "when_source_is_url__should_return_a_segments", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_segmentation(source: str, source_type: DocumentType): + client = AI21Client() + + response = client.segmentation.create( + source=source, + source_type=source_type, + ) + + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_type is not None + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + # "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + # (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_segmentation__source_and_source_type_misalignment(source: str, source_type: DocumentType): + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.segmentation.create( + source=source, + source_type=source_type, + ) diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py new file mode 100644 index 00000000..6c7ae4e9 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize.py @@ -0,0 +1,62 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType, SummaryMethod + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_suggestion", + "when_source_is_url__should_return_a_suggestion", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_summarize(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + response = client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) + assert response.summary is not None + assert focus in response.summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize__source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) diff --git a/tests/integration_tests/clients/studio/test_summarize_by_segment.py b/tests/integration_tests/clients/studio/test_summarize_by_segment.py new file mode 100644 index 00000000..f39dd308 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize_by_segment.py @@ -0,0 +1,66 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the Netherlands. + From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +def test_summarize_by_segment__when_text__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_TEXT, + source_type=DocumentType.TEXT, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_html is None + assert isinstance(response.segments[0].summary, str) + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].segment_type == "normal_text" + assert response.segments[0].has_summary + + +def test_summarize_by_segment__when_url__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_URL, + source_type=DocumentType.URL, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert isinstance(response.segments[0].segment_html, str) + assert isinstance(response.segments[0].summary, str) + assert response.segments[0].segment_type == "normal_text" + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].has_summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize_by_segment__source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize_by_segment.create( + source=source, + source_type=source_type, + focus=focus, + )