From 75e4d4e7b1247b1f70285d172ccb399e053cff41 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 24 Jan 2024 17:55:01 +0200 Subject: [PATCH 01/27] fix: types --- ai21/clients/common/answer_base.py | 2 +- ai21/clients/studio/resources/studio_embed.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index a1543646..03b6c71d 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -26,7 +26,7 @@ def _create_body( self, context: str, question: str, - answer_length: Optional[str], + answer_length: Optional[AnswerLength], mode: Optional[str], ) -> Dict[str, Any]: return {"context": context, "question": question, "answerLength": answer_length, "mode": mode} diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 7495ea67..7e7c8fad 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -2,11 +2,12 @@ from ai21.clients.common.embed_base import Embed from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.embed_type import EmbedType from ai21.models.responses.embed_response import EmbedResponse class StudioEmbed(StudioResource, Embed): - def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse: + def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type) response = self._post(url=url, body=body) From 7242325c41e3e1ad82d63331ca0b5876ad27005f Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 24 Jan 2024 17:55:21 +0200 Subject: [PATCH 02/27] test: Added some integration tests --- .../clients/studio/__init__.py | 0 .../clients/studio/test_answer.py | 51 +++++++ .../clients/studio/test_chat.py | 94 ++++++++++++ .../clients/studio/test_completion.py | 137 ++++++++++++++++++ .../clients/studio/test_embed.py | 29 ++++ .../clients/studio/test_gec.py | 37 +++++ 6 files changed, 348 insertions(+) create mode 100644 tests/integration_tests/clients/studio/__init__.py create mode 100644 tests/integration_tests/clients/studio/test_answer.py create mode 100644 tests/integration_tests/clients/studio/test_chat.py create mode 100644 tests/integration_tests/clients/studio/test_completion.py create mode 100644 tests/integration_tests/clients/studio/test_embed.py create mode 100644 tests/integration_tests/clients/studio/test_gec.py diff --git a/tests/integration_tests/clients/studio/__init__.py b/tests/integration_tests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py new file mode 100644 index 00000000..a611085f --- /dev/null +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -0,0 +1,51 @@ +import pytest +from ai21 import AI21Client +from ai21.models import AnswerLength, Mode + +_CONTEXT = ( + "Holland is a geographical region[2] and former province on the western coast of" + " the Netherlands.[2] From the " + "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " + "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " + "economic power, dominating the other provinces of the newly independent Dutch Republic." +) + + +@pytest.mark.parametrize( + ids=[ + "when_answer_is_in_context", + "when_answer_not_in_context", + ], + argnames=["question", "is_answer_in_context", "expected_answer_type"], + argvalues=[ + ("When did Holland become an economic power?", True, str), + ("Is the ocean blue?", False, None), + ], +) +def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type): + client = AI21Client() + response = client.answer.create( + context=_CONTEXT, + question=question, + answer_length=AnswerLength.LONG, + mode=Mode.FLEXIBLE, + ) + + assert response.answer_in_context == is_answer_in_context + if is_answer_in_context: + assert isinstance(response.answer, str) + else: + assert response.answer is None + + +def test_answer__when_answer_length(): + client = AI21Client() + response = client.answer.create( + context=_CONTEXT, + question="Can you please tell me everything you know about Holland?", + answer_length=AnswerLength.SHORT, + mode="flexible", + ) + print("--------\n") + print(response.answer) + print(len(response.answer)) diff --git a/tests/integration_tests/clients/studio/test_chat.py b/tests/integration_tests/clients/studio/test_chat.py new file mode 100644 index 00000000..70d26761 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_chat.py @@ -0,0 +1,94 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ChatMessage, RoleType, Penalty, FinishReason + +_MODEL = "j2-ultra" +_MESSAGES = [ + ChatMessage( + text="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", + role=RoleType.USER, + ), +] +_SYSTEM = "You are a teacher in a public school" + + +def test_chat(): + num_results = 5 + messages = _MESSAGES + + client = AI21Client() + response = client.chat.create( + system=_SYSTEM, + messages=messages, + num_results=num_results, + max_tokens=64, + temperature=0.7, + min_tokens=1, + stop_sequences=["\n"], + top_p=0.3, + top_k_return=0, + model=_MODEL, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.outputs[0].role == RoleType.ASSISTANT + assert isinstance(response.outputs[0].text, str) + assert response.outputs[0].finish_reason == FinishReason(reason="endoftext") + + assert len(response.outputs) == num_results + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (2, "##", "length"), + (100, "##", "endoftext"), + (20, ".", "stop"), + ], +) +def test_chat_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.chat.create( + messages=_MESSAGES, + system=_SYSTEM, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + ) + + assert response.outputs[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py new file mode 100644 index 00000000..440b5b89 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -0,0 +1,137 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import Penalty + +_PROMPT = ( + "The following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- " + "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" + " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " + "Hi there, happy to help!\nUser: Is there no way to return a product?" + " I got your blue T-Shirt size small but it doesn't fit.\n" + "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" + "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" + "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help. The following are important points about" + " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" + 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' + "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" + " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" + "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" + " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" + " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" + " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" + "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" + "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" + "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" + " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" + " are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" + "User: Hi, I have a question for you" +) + + +def test_completion(): + num_results = 3 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=1, + top_p=0.2, + top_k_return=0, + stop_sequences=["##"], + num_results=num_results, + custom_model=None, + epoch=1, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Check the results aren't all the same + assert len(set([completion.data.text for completion in response.completions])) == num_results + for completion in response.completions: + assert isinstance(completion.data.text, str) + + +def test_completion_when_temperature_1_and_top_p_is_0__should_return_same_response(): + num_results = 5 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=1, + top_p=0, + top_k_return=0, + num_results=num_results, + epoch=1, + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Verify all results are the same + assert len(set([completion.data.text for completion in response.completions])) == 1 + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (10, "##", "length"), + (100, "##", "endoftext"), + (50, "\n", "stop"), + ], +) +def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + epoch=1, + ) + + assert response.completions[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_embed.py b/tests/integration_tests/clients/studio/test_embed.py new file mode 100644 index 00000000..ea260d76 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_embed.py @@ -0,0 +1,29 @@ +from typing import List + +import pytest +from ai21 import AI21Client +from ai21.models import EmbedType + +_TEXT_0 = "Holland is a geographical region and former province on the western coast of the Netherlands." +_TEXT_1 = "Germany is a country in Central Europe. It is the second-most populous country in Europe after Russia" + + +@pytest.mark.parametrize( + ids=[ + "when_single_text__should_return_single_embedding", + "when_multiple_text__should_return_multiple_embeddings", + ], + argnames=["texts"], + argvalues=[ + ([_TEXT_0],), + ([_TEXT_0, _TEXT_1],), + ], +) +def test_embed(texts: List[str]): + client = AI21Client() + response = client.embed.create( + texts=texts, + type=EmbedType.QUERY, + ) + + assert len(response.results) == len(texts) diff --git a/tests/integration_tests/clients/studio/test_gec.py b/tests/integration_tests/clients/studio/test_gec.py new file mode 100644 index 00000000..1503c15c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_gec.py @@ -0,0 +1,37 @@ +import pytest +from ai21 import AI21Client +from ai21.models import CorrectionType + + +@pytest.mark.parametrize( + ids=[ + "should_fix_spelling", + "should_fix_grammar", + "should_fix_missing_word", + "should_fix_punctuation", + "should_fix_wrong_word", + # "should_fix_word_repetition", + ], + argnames=["text", "correction_type", "expected_suggestion"], + argvalues=[ + ("jazzz is music", CorrectionType.SPELLING, "Jazz"), + ("You am nice", CorrectionType.GRAMMAR, "are"), + ( + "He stared out the window, lost in thought, as the raindrops against the glass.", + CorrectionType.MISSING_WORD, + "raindrops fell against", + ), + ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), + ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), + # ( + # "The mountain was tall, and the tall mountain could be seen from miles away.", + # CorrectionType.WORD_REPETITION, + # "like", + # ), + ], +) +def test_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): + client = AI21Client() + response = client.gec.create(text=text) + assert response.corrections[0].suggestion == expected_suggestion + assert response.corrections[0].correction_type == correction_type From 1525a0e5a9fda01daad111574e1b33365abf44c4 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 25 Jan 2024 11:57:09 +0200 Subject: [PATCH 03/27] test: improvements --- .../clients/studio/test_answer.py | 15 +---------- .../clients/studio/test_improvements.py | 13 ++++++++++ .../clients/studio/test_paraphrase.py | 26 +++++++++++++++++++ 3 files changed, 40 insertions(+), 14 deletions(-) create mode 100644 tests/integration_tests/clients/studio/test_improvements.py create mode 100644 tests/integration_tests/clients/studio/test_paraphrase.py diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py index a611085f..51dd4fa2 100644 --- a/tests/integration_tests/clients/studio/test_answer.py +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -4,7 +4,7 @@ _CONTEXT = ( "Holland is a geographical region[2] and former province on the western coast of" - " the Netherlands.[2] From the " + " the Netherlands. From the " "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " "economic power, dominating the other provinces of the newly independent Dutch Republic." @@ -36,16 +36,3 @@ def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: assert isinstance(response.answer, str) else: assert response.answer is None - - -def test_answer__when_answer_length(): - client = AI21Client() - response = client.answer.create( - context=_CONTEXT, - question="Can you please tell me everything you know about Holland?", - answer_length=AnswerLength.SHORT, - mode="flexible", - ) - print("--------\n") - print(response.answer) - print(len(response.answer)) diff --git a/tests/integration_tests/clients/studio/test_improvements.py b/tests/integration_tests/clients/studio/test_improvements.py new file mode 100644 index 00000000..3488f781 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_improvements.py @@ -0,0 +1,13 @@ +from ai21 import AI21Client +from ai21.models import ImprovementType + + +def test_improvements(): + client = AI21Client() + response = client.improvements.create( + text="Affiliated with the profession of project management," + " I have ameliorated myself with a different set of hard skills as well as soft skills.", + types=[ImprovementType.FLUENCY], + ) + + assert len(response.improvements) > 0 diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py new file mode 100644 index 00000000..ed9f7272 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_paraphrase.py @@ -0,0 +1,26 @@ +from ai21 import AI21Client +from ai21.models import ParaphraseStyleType + + +def test_paraphrase(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.FORMAL, + start_index=0, + end_index=20, + ) + for suggestion in response.suggestions: + print(suggestion.text) + assert len(response.suggestions) > 0 + + +def test_paraphrase__when_start_and_end_index_is_small__should_not_return_suggestions(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.GENERAL, + start_index=0, + end_index=5, + ) + assert len(response.suggestions) == 0 From cc7fbe17067538ecc1137408d0d2dd2f5b1661ae Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 19:33:21 +0200 Subject: [PATCH 04/27] test: test_paraphrase.py --- .../clients/studio/test_paraphrase.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py index ed9f7272..a7ba93aa 100644 --- a/tests/integration_tests/clients/studio/test_paraphrase.py +++ b/tests/integration_tests/clients/studio/test_paraphrase.py @@ -1,3 +1,5 @@ +import pytest + from ai21 import AI21Client from ai21.models import ParaphraseStyleType @@ -24,3 +26,26 @@ def test_paraphrase__when_start_and_end_index_is_small__should_not_return_sugges end_index=5, ) assert len(response.suggestions) == 0 + + +@pytest.mark.parametrize( + ids=["when_general", "when_casual", "when_long", "when_short", "when_formal"], + argnames=["style"], + argvalues=[ + (ParaphraseStyleType.GENERAL,), + (ParaphraseStyleType.CASUAL,), + (ParaphraseStyleType.LONG,), + (ParaphraseStyleType.SHORT,), + (ParaphraseStyleType.FORMAL,), + ], +) +def test_paraphrase_styles(style: ParaphraseStyleType): + client = AI21Client() + response = client.paraphrase.create( + text="Today is a beautiful day.", + style=style, + start_index=0, + end_index=25, + ) + + assert len(response.suggestions) > 0 From 52f300efdf99b96d492478d4696ab4061447a4dc Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 19:41:24 +0200 Subject: [PATCH 05/27] fix: doc --- ai21/clients/common/summarize_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py index 85cec2f5..c74d4538 100644 --- a/ai21/clients/common/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -16,6 +16,14 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: + """ + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Return only summaries focused on a topic of your choice. + :param summary_method: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse: From 3d54cdebf22dda57d172f1302892728ada77dc23 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 19:41:39 +0200 Subject: [PATCH 06/27] fix: removed unused comment --- ai21/clients/studio/resources/studio_summarize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index b2b5f860..4180ff52 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -18,7 +18,6 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: - # Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object. body = self._create_body( source=source, source_type=source_type, From 0197c8ca23f7663d4ac8e25b74ed0fa09cc04e32 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 20:11:41 +0200 Subject: [PATCH 07/27] test: test_summarize.py --- .../clients/studio/test_summarize.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/integration_tests/clients/studio/test_summarize.py diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py new file mode 100644 index 00000000..7153c921 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize.py @@ -0,0 +1,62 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType, SummaryMethod + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_suggestion", + "when_source_is_url__should_return_a_suggestion", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_summarize(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + response = client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) + assert response.summary is not None + assert focus in response.summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize_source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) From d9c148e7f8f689b6803a44be6d31b2d16e199bf0 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 09:41:09 +0200 Subject: [PATCH 08/27] test: Added tests for test_summarize_by_segment.py --- .../clients/studio/test_summarize.py | 2 +- .../studio/test_summarize_by_segment.py | 66 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/clients/studio/test_summarize_by_segment.py diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py index 7153c921..6c7ae4e9 100644 --- a/tests/integration_tests/clients/studio/test_summarize.py +++ b/tests/integration_tests/clients/studio/test_summarize.py @@ -49,7 +49,7 @@ def test_summarize(source: str, source_type: DocumentType): (_SOURCE_URL, DocumentType.TEXT), ], ) -def test_summarize_source_and_source_type_misalignment(source: str, source_type: DocumentType): +def test_summarize__source_and_source_type_misalignment(source: str, source_type: DocumentType): focus = "Holland" client = AI21Client() diff --git a/tests/integration_tests/clients/studio/test_summarize_by_segment.py b/tests/integration_tests/clients/studio/test_summarize_by_segment.py new file mode 100644 index 00000000..f39dd308 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize_by_segment.py @@ -0,0 +1,66 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the Netherlands. + From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +def test_summarize_by_segment__when_text__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_TEXT, + source_type=DocumentType.TEXT, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_html is None + assert isinstance(response.segments[0].summary, str) + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].segment_type == "normal_text" + assert response.segments[0].has_summary + + +def test_summarize_by_segment__when_url__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_URL, + source_type=DocumentType.URL, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert isinstance(response.segments[0].segment_html, str) + assert isinstance(response.segments[0].summary, str) + assert response.segments[0].segment_type == "normal_text" + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].has_summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize_by_segment__source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize_by_segment.create( + source=source, + source_type=source_type, + focus=focus, + ) From e644bd2f8b2de595a048d73d6f62860d2d36235b Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 11:26:03 +0200 Subject: [PATCH 09/27] test: test_segmentation.py --- .../clients/studio/test_segmentation.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/integration_tests/clients/studio/test_segmentation.py diff --git a/tests/integration_tests/clients/studio/test_segmentation.py b/tests/integration_tests/clients/studio/test_segmentation.py new file mode 100644 index 00000000..ede8707c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_segmentation.py @@ -0,0 +1,55 @@ +import pytest +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_segments", + "when_source_is_url__should_return_a_segments", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_segmentation(source: str, source_type: DocumentType): + client = AI21Client() + + response = client.segmentation.create( + source=source, + source_type=source_type, + ) + + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_type is not None + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + # "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + # (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_segmentation__source_and_source_type_misalignment(source: str, source_type: DocumentType): + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.segmentation.create( + source=source, + source_type=source_type, + ) From 1adb281bed978a0455b74fc373507f520fea4361 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 14:13:23 +0200 Subject: [PATCH 10/27] fix: file id in library response --- ai21/models/responses/library_answer_response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai21/models/responses/library_answer_response.py b/ai21/models/responses/library_answer_response.py index 36341eda..28fab165 100644 --- a/ai21/models/responses/library_answer_response.py +++ b/ai21/models/responses/library_answer_response.py @@ -6,7 +6,7 @@ @dataclass class SourceDocument(AI21BaseModelMixin): - field_id: str + file_id: str name: str highlights: List[str] public_url: Optional[str] = None From 3ad97e4ca439d93275fb1cdb3251ebfaebb18ccc Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 14:13:44 +0200 Subject: [PATCH 11/27] fix: example for library --- examples/studio/library.py | 8 +++++++- examples/studio/library_answer.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/studio/library.py b/examples/studio/library.py index d693d697..ca8e4840 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -22,7 +22,12 @@ def validate_file_deleted(): file_path = os.getcwd() path = os.path.join(file_path, file_name) -file_utils.create_file(file_path, file_name, content="test content" * 100) +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" +file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT) file_id = client.library.files.create( file_path=path, @@ -31,6 +36,7 @@ def validate_file_deleted(): public_url="www.example.com", ) print(file_id) + files = client.library.files.list() print(files) uploaded_file = client.library.files.get(file_id) diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py index 54b2bb1d..6ec3f092 100644 --- a/examples/studio/library_answer.py +++ b/examples/studio/library_answer.py @@ -2,5 +2,5 @@ client = AI21Client() -response = client.library.answer.create(question="Where is Thailand?") +response = client.library.answer.create(question="Tell me something about Holland") print(response) From 0eacdbb07e62ed415e6c8cfe0cd02b2296c63eb4 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:54:53 +0200 Subject: [PATCH 12/27] ci: Add rc branch prefix trigger for integration tests (#43) * ci: rc branch trigger for integration test * fix: wrapped in quotes --- .github/workflows/integration-tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 83291bbc..ee23e90d 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -4,6 +4,7 @@ on: push: branches: - main + - "rc_*" env: POETRY_VERSION: "1.7.1" From dc88a83fb50aafd255160c3c5992af288fc52141 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 24 Jan 2024 17:55:01 +0200 Subject: [PATCH 13/27] fix: types --- ai21/clients/common/answer_base.py | 2 +- ai21/clients/studio/resources/studio_embed.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index a1543646..03b6c71d 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -26,7 +26,7 @@ def _create_body( self, context: str, question: str, - answer_length: Optional[str], + answer_length: Optional[AnswerLength], mode: Optional[str], ) -> Dict[str, Any]: return {"context": context, "question": question, "answerLength": answer_length, "mode": mode} diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 7495ea67..7e7c8fad 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -2,11 +2,12 @@ from ai21.clients.common.embed_base import Embed from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.embed_type import EmbedType from ai21.models.responses.embed_response import EmbedResponse class StudioEmbed(StudioResource, Embed): - def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse: + def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type) response = self._post(url=url, body=body) From e119340989a5a7da8d14b9ae2861c25299ef1107 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 24 Jan 2024 17:55:21 +0200 Subject: [PATCH 14/27] test: Added some integration tests --- .../clients/studio/__init__.py | 0 .../clients/studio/test_answer.py | 51 +++++++ .../clients/studio/test_chat.py | 94 ++++++++++++ .../clients/studio/test_completion.py | 137 ++++++++++++++++++ .../clients/studio/test_embed.py | 29 ++++ .../clients/studio/test_gec.py | 37 +++++ 6 files changed, 348 insertions(+) create mode 100644 tests/integration_tests/clients/studio/__init__.py create mode 100644 tests/integration_tests/clients/studio/test_answer.py create mode 100644 tests/integration_tests/clients/studio/test_chat.py create mode 100644 tests/integration_tests/clients/studio/test_completion.py create mode 100644 tests/integration_tests/clients/studio/test_embed.py create mode 100644 tests/integration_tests/clients/studio/test_gec.py diff --git a/tests/integration_tests/clients/studio/__init__.py b/tests/integration_tests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py new file mode 100644 index 00000000..a611085f --- /dev/null +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -0,0 +1,51 @@ +import pytest +from ai21 import AI21Client +from ai21.models import AnswerLength, Mode + +_CONTEXT = ( + "Holland is a geographical region[2] and former province on the western coast of" + " the Netherlands.[2] From the " + "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " + "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " + "economic power, dominating the other provinces of the newly independent Dutch Republic." +) + + +@pytest.mark.parametrize( + ids=[ + "when_answer_is_in_context", + "when_answer_not_in_context", + ], + argnames=["question", "is_answer_in_context", "expected_answer_type"], + argvalues=[ + ("When did Holland become an economic power?", True, str), + ("Is the ocean blue?", False, None), + ], +) +def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type): + client = AI21Client() + response = client.answer.create( + context=_CONTEXT, + question=question, + answer_length=AnswerLength.LONG, + mode=Mode.FLEXIBLE, + ) + + assert response.answer_in_context == is_answer_in_context + if is_answer_in_context: + assert isinstance(response.answer, str) + else: + assert response.answer is None + + +def test_answer__when_answer_length(): + client = AI21Client() + response = client.answer.create( + context=_CONTEXT, + question="Can you please tell me everything you know about Holland?", + answer_length=AnswerLength.SHORT, + mode="flexible", + ) + print("--------\n") + print(response.answer) + print(len(response.answer)) diff --git a/tests/integration_tests/clients/studio/test_chat.py b/tests/integration_tests/clients/studio/test_chat.py new file mode 100644 index 00000000..70d26761 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_chat.py @@ -0,0 +1,94 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ChatMessage, RoleType, Penalty, FinishReason + +_MODEL = "j2-ultra" +_MESSAGES = [ + ChatMessage( + text="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", + role=RoleType.USER, + ), +] +_SYSTEM = "You are a teacher in a public school" + + +def test_chat(): + num_results = 5 + messages = _MESSAGES + + client = AI21Client() + response = client.chat.create( + system=_SYSTEM, + messages=messages, + num_results=num_results, + max_tokens=64, + temperature=0.7, + min_tokens=1, + stop_sequences=["\n"], + top_p=0.3, + top_k_return=0, + model=_MODEL, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.outputs[0].role == RoleType.ASSISTANT + assert isinstance(response.outputs[0].text, str) + assert response.outputs[0].finish_reason == FinishReason(reason="endoftext") + + assert len(response.outputs) == num_results + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (2, "##", "length"), + (100, "##", "endoftext"), + (20, ".", "stop"), + ], +) +def test_chat_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.chat.create( + messages=_MESSAGES, + system=_SYSTEM, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + ) + + assert response.outputs[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py new file mode 100644 index 00000000..440b5b89 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -0,0 +1,137 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import Penalty + +_PROMPT = ( + "The following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- " + "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" + " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " + "Hi there, happy to help!\nUser: Is there no way to return a product?" + " I got your blue T-Shirt size small but it doesn't fit.\n" + "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" + "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" + "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help. The following are important points about" + " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" + 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' + "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" + " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" + "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" + " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" + " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" + " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" + "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" + "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" + "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" + " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" + " are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" + "User: Hi, I have a question for you" +) + + +def test_completion(): + num_results = 3 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=1, + top_p=0.2, + top_k_return=0, + stop_sequences=["##"], + num_results=num_results, + custom_model=None, + epoch=1, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Check the results aren't all the same + assert len(set([completion.data.text for completion in response.completions])) == num_results + for completion in response.completions: + assert isinstance(completion.data.text, str) + + +def test_completion_when_temperature_1_and_top_p_is_0__should_return_same_response(): + num_results = 5 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=1, + top_p=0, + top_k_return=0, + num_results=num_results, + epoch=1, + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Verify all results are the same + assert len(set([completion.data.text for completion in response.completions])) == 1 + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (10, "##", "length"), + (100, "##", "endoftext"), + (50, "\n", "stop"), + ], +) +def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + epoch=1, + ) + + assert response.completions[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_embed.py b/tests/integration_tests/clients/studio/test_embed.py new file mode 100644 index 00000000..ea260d76 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_embed.py @@ -0,0 +1,29 @@ +from typing import List + +import pytest +from ai21 import AI21Client +from ai21.models import EmbedType + +_TEXT_0 = "Holland is a geographical region and former province on the western coast of the Netherlands." +_TEXT_1 = "Germany is a country in Central Europe. It is the second-most populous country in Europe after Russia" + + +@pytest.mark.parametrize( + ids=[ + "when_single_text__should_return_single_embedding", + "when_multiple_text__should_return_multiple_embeddings", + ], + argnames=["texts"], + argvalues=[ + ([_TEXT_0],), + ([_TEXT_0, _TEXT_1],), + ], +) +def test_embed(texts: List[str]): + client = AI21Client() + response = client.embed.create( + texts=texts, + type=EmbedType.QUERY, + ) + + assert len(response.results) == len(texts) diff --git a/tests/integration_tests/clients/studio/test_gec.py b/tests/integration_tests/clients/studio/test_gec.py new file mode 100644 index 00000000..1503c15c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_gec.py @@ -0,0 +1,37 @@ +import pytest +from ai21 import AI21Client +from ai21.models import CorrectionType + + +@pytest.mark.parametrize( + ids=[ + "should_fix_spelling", + "should_fix_grammar", + "should_fix_missing_word", + "should_fix_punctuation", + "should_fix_wrong_word", + # "should_fix_word_repetition", + ], + argnames=["text", "correction_type", "expected_suggestion"], + argvalues=[ + ("jazzz is music", CorrectionType.SPELLING, "Jazz"), + ("You am nice", CorrectionType.GRAMMAR, "are"), + ( + "He stared out the window, lost in thought, as the raindrops against the glass.", + CorrectionType.MISSING_WORD, + "raindrops fell against", + ), + ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), + ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), + # ( + # "The mountain was tall, and the tall mountain could be seen from miles away.", + # CorrectionType.WORD_REPETITION, + # "like", + # ), + ], +) +def test_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): + client = AI21Client() + response = client.gec.create(text=text) + assert response.corrections[0].suggestion == expected_suggestion + assert response.corrections[0].correction_type == correction_type From e7b461a9ceaffc3c959e25c94245f298fdb067b2 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 25 Jan 2024 11:57:09 +0200 Subject: [PATCH 15/27] test: improvements --- .../clients/studio/test_answer.py | 15 +---------- .../clients/studio/test_improvements.py | 13 ++++++++++ .../clients/studio/test_paraphrase.py | 26 +++++++++++++++++++ 3 files changed, 40 insertions(+), 14 deletions(-) create mode 100644 tests/integration_tests/clients/studio/test_improvements.py create mode 100644 tests/integration_tests/clients/studio/test_paraphrase.py diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py index a611085f..51dd4fa2 100644 --- a/tests/integration_tests/clients/studio/test_answer.py +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -4,7 +4,7 @@ _CONTEXT = ( "Holland is a geographical region[2] and former province on the western coast of" - " the Netherlands.[2] From the " + " the Netherlands. From the " "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " "economic power, dominating the other provinces of the newly independent Dutch Republic." @@ -36,16 +36,3 @@ def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: assert isinstance(response.answer, str) else: assert response.answer is None - - -def test_answer__when_answer_length(): - client = AI21Client() - response = client.answer.create( - context=_CONTEXT, - question="Can you please tell me everything you know about Holland?", - answer_length=AnswerLength.SHORT, - mode="flexible", - ) - print("--------\n") - print(response.answer) - print(len(response.answer)) diff --git a/tests/integration_tests/clients/studio/test_improvements.py b/tests/integration_tests/clients/studio/test_improvements.py new file mode 100644 index 00000000..3488f781 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_improvements.py @@ -0,0 +1,13 @@ +from ai21 import AI21Client +from ai21.models import ImprovementType + + +def test_improvements(): + client = AI21Client() + response = client.improvements.create( + text="Affiliated with the profession of project management," + " I have ameliorated myself with a different set of hard skills as well as soft skills.", + types=[ImprovementType.FLUENCY], + ) + + assert len(response.improvements) > 0 diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py new file mode 100644 index 00000000..ed9f7272 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_paraphrase.py @@ -0,0 +1,26 @@ +from ai21 import AI21Client +from ai21.models import ParaphraseStyleType + + +def test_paraphrase(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.FORMAL, + start_index=0, + end_index=20, + ) + for suggestion in response.suggestions: + print(suggestion.text) + assert len(response.suggestions) > 0 + + +def test_paraphrase__when_start_and_end_index_is_small__should_not_return_suggestions(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.GENERAL, + start_index=0, + end_index=5, + ) + assert len(response.suggestions) == 0 From dc183489948b0b9b073741189877fd75bc7ce9de Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 19:33:21 +0200 Subject: [PATCH 16/27] test: test_paraphrase.py --- .../clients/studio/test_paraphrase.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py index ed9f7272..a7ba93aa 100644 --- a/tests/integration_tests/clients/studio/test_paraphrase.py +++ b/tests/integration_tests/clients/studio/test_paraphrase.py @@ -1,3 +1,5 @@ +import pytest + from ai21 import AI21Client from ai21.models import ParaphraseStyleType @@ -24,3 +26,26 @@ def test_paraphrase__when_start_and_end_index_is_small__should_not_return_sugges end_index=5, ) assert len(response.suggestions) == 0 + + +@pytest.mark.parametrize( + ids=["when_general", "when_casual", "when_long", "when_short", "when_formal"], + argnames=["style"], + argvalues=[ + (ParaphraseStyleType.GENERAL,), + (ParaphraseStyleType.CASUAL,), + (ParaphraseStyleType.LONG,), + (ParaphraseStyleType.SHORT,), + (ParaphraseStyleType.FORMAL,), + ], +) +def test_paraphrase_styles(style: ParaphraseStyleType): + client = AI21Client() + response = client.paraphrase.create( + text="Today is a beautiful day.", + style=style, + start_index=0, + end_index=25, + ) + + assert len(response.suggestions) > 0 From 067497d19ca35ad244a9d29bc4ad188e8123384e Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 19:41:24 +0200 Subject: [PATCH 17/27] fix: doc --- ai21/clients/common/summarize_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py index 85cec2f5..c74d4538 100644 --- a/ai21/clients/common/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -16,6 +16,14 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: + """ + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Return only summaries focused on a topic of your choice. + :param summary_method: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse: From a952acdcbdf80979f01b6cf3d0497f8a2909417e Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 19:41:39 +0200 Subject: [PATCH 18/27] fix: removed unused comment --- ai21/clients/studio/resources/studio_summarize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index b2b5f860..4180ff52 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -18,7 +18,6 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: - # Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object. body = self._create_body( source=source, source_type=source_type, From 5d703995e3eac444b30d95e37937cbe077644b64 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sat, 27 Jan 2024 20:11:41 +0200 Subject: [PATCH 19/27] test: test_summarize.py --- .../clients/studio/test_summarize.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/integration_tests/clients/studio/test_summarize.py diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py new file mode 100644 index 00000000..7153c921 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize.py @@ -0,0 +1,62 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType, SummaryMethod + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_suggestion", + "when_source_is_url__should_return_a_suggestion", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_summarize(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + response = client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) + assert response.summary is not None + assert focus in response.summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize_source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) From 36019845dcc74e0bafa3d857657647629f7147bf Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 09:41:09 +0200 Subject: [PATCH 20/27] test: Added tests for test_summarize_by_segment.py --- .../clients/studio/test_summarize.py | 2 +- .../studio/test_summarize_by_segment.py | 66 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/clients/studio/test_summarize_by_segment.py diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py index 7153c921..6c7ae4e9 100644 --- a/tests/integration_tests/clients/studio/test_summarize.py +++ b/tests/integration_tests/clients/studio/test_summarize.py @@ -49,7 +49,7 @@ def test_summarize(source: str, source_type: DocumentType): (_SOURCE_URL, DocumentType.TEXT), ], ) -def test_summarize_source_and_source_type_misalignment(source: str, source_type: DocumentType): +def test_summarize__source_and_source_type_misalignment(source: str, source_type: DocumentType): focus = "Holland" client = AI21Client() diff --git a/tests/integration_tests/clients/studio/test_summarize_by_segment.py b/tests/integration_tests/clients/studio/test_summarize_by_segment.py new file mode 100644 index 00000000..f39dd308 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize_by_segment.py @@ -0,0 +1,66 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the Netherlands. + From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +def test_summarize_by_segment__when_text__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_TEXT, + source_type=DocumentType.TEXT, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_html is None + assert isinstance(response.segments[0].summary, str) + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].segment_type == "normal_text" + assert response.segments[0].has_summary + + +def test_summarize_by_segment__when_url__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_URL, + source_type=DocumentType.URL, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert isinstance(response.segments[0].segment_html, str) + assert isinstance(response.segments[0].summary, str) + assert response.segments[0].segment_type == "normal_text" + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].has_summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize_by_segment__source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize_by_segment.create( + source=source, + source_type=source_type, + focus=focus, + ) From 85669a286c9563f670973d541254db5b591d95b9 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 11:26:03 +0200 Subject: [PATCH 21/27] test: test_segmentation.py --- .../clients/studio/test_segmentation.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/integration_tests/clients/studio/test_segmentation.py diff --git a/tests/integration_tests/clients/studio/test_segmentation.py b/tests/integration_tests/clients/studio/test_segmentation.py new file mode 100644 index 00000000..ede8707c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_segmentation.py @@ -0,0 +1,55 @@ +import pytest +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_segments", + "when_source_is_url__should_return_a_segments", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_segmentation(source: str, source_type: DocumentType): + client = AI21Client() + + response = client.segmentation.create( + source=source, + source_type=source_type, + ) + + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_type is not None + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + # "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + # (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_segmentation__source_and_source_type_misalignment(source: str, source_type: DocumentType): + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.segmentation.create( + source=source, + source_type=source_type, + ) From dfc1f7731491dc0101136f761f4f74aaac27757b Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 14:13:23 +0200 Subject: [PATCH 22/27] fix: file id in library response --- ai21/models/responses/library_answer_response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai21/models/responses/library_answer_response.py b/ai21/models/responses/library_answer_response.py index 36341eda..28fab165 100644 --- a/ai21/models/responses/library_answer_response.py +++ b/ai21/models/responses/library_answer_response.py @@ -6,7 +6,7 @@ @dataclass class SourceDocument(AI21BaseModelMixin): - field_id: str + file_id: str name: str highlights: List[str] public_url: Optional[str] = None From 079b02ee4f2d559f99250a002418c7c4799ffe39 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 28 Jan 2024 14:13:44 +0200 Subject: [PATCH 23/27] fix: example for library --- examples/studio/library.py | 8 +++++++- examples/studio/library_answer.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/studio/library.py b/examples/studio/library.py index d693d697..ca8e4840 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -22,7 +22,12 @@ def validate_file_deleted(): file_path = os.getcwd() path = os.path.join(file_path, file_name) -file_utils.create_file(file_path, file_name, content="test content" * 100) +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" +file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT) file_id = client.library.files.create( file_path=path, @@ -31,6 +36,7 @@ def validate_file_deleted(): public_url="www.example.com", ) print(file_id) + files = client.library.files.list() print(files) uploaded_file = client.library.files.get(file_id) diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py index 54b2bb1d..6ec3f092 100644 --- a/examples/studio/library_answer.py +++ b/examples/studio/library_answer.py @@ -2,5 +2,5 @@ client = AI21Client() -response = client.library.answer.create(question="Where is Thailand?") +response = client.library.answer.create(question="Tell me something about Holland") print(response) From ae67c6e886079d8ac9abe11d5802001bbb76fd22 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 30 Jan 2024 10:32:38 +0200 Subject: [PATCH 24/27] docs: docstrings --- ai21/clients/common/answer_base.py | 9 +++++++++ ai21/clients/common/chat_base.py | 19 +++++++++++++++++++ ai21/clients/common/completion_base.py | 18 ++++++++++++++++++ ai21/clients/common/custom_model_base.py | 10 ++++++++++ ai21/clients/common/dataset_base.py | 11 +++++++++++ ai21/clients/common/embed_base.py | 8 ++++++++ ai21/clients/common/gec_base.py | 6 ++++++ ai21/clients/common/improvements_base.py | 7 +++++++ ai21/clients/common/paraphrase_base.py | 10 ++++++++++ ai21/clients/common/segmentation_base.py | 7 +++++++ ai21/clients/common/summarize_base.py | 2 +- .../common/summarize_by_segment_base.py | 8 ++++++++ 12 files changed, 114 insertions(+), 1 deletion(-) diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index 03b6c71d..1821db9a 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -17,6 +17,15 @@ def create( mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: + """ + + :param context: A string containing the document context for which the question will be answered + :param question: A string containing the question to be answered based on the provided context. + :param answer_length: Approximate length of the answer in words. + :param mode: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse: diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index e73dba3c..03037491 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -28,6 +28,25 @@ def create( count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: + """ + + :param model: model type you wish to interact with + :param messages: A sequence of messages ingested by the model, which then returns the assistant's response + :param system: Offers the model overarching guidance on its response approach, encapsulating context, tone, + guardrails, and more + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index a2bc8c3d..06fef7d8 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -28,6 +28,24 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: + """ + :param model: model type you wish to interact with + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param custom_model: + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param epoch: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse: diff --git a/ai21/clients/common/custom_model_base.py b/ai21/clients/common/custom_model_base.py index 7d3b55ae..303b1adf 100644 --- a/ai21/clients/common/custom_model_base.py +++ b/ai21/clients/common/custom_model_base.py @@ -18,6 +18,16 @@ def create( num_epochs: Optional[int] = None, **kwargs, ) -> None: + """ + + :param dataset_id: The dataset you want to train your model on. + :param model_name: The name of your trained model + :param model_type: The type of model to train. + :param learning_rate: The learning rate used for training. + :param num_epochs: Number of epochs for training + :param kwargs: + :return: + """ pass @abstractmethod diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py index 9fa57f85..732dee39 100644 --- a/ai21/clients/common/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -19,6 +19,17 @@ def create( split_ratio: Optional[float] = None, **kwargs, ): + """ + + :param file_path: Local path to dataset + :param dataset_name: Dataset name. Must be unique + :param selected_columns: Mapping of the columns in the dataset file to prompt and completion columns. + :param approve_whitespace_correction: Automatically correct examples that violate best practices + :param delete_long_rows: Allow removal of examples where prompt + completion lengths exceeds 2047 tokens + :param split_ratio: + :param kwargs: + :return: + """ pass @abstractmethod diff --git a/ai21/clients/common/embed_base.py b/ai21/clients/common/embed_base.py index baadd4ec..aaf9363e 100644 --- a/ai21/clients/common/embed_base.py +++ b/ai21/clients/common/embed_base.py @@ -10,6 +10,14 @@ class Embed(ABC): @abstractmethod def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: + """ + + :param texts: A list of strings, each representing a document or segment of text to be embedded. + :param type: For retrieval/search use cases, indicates whether the texts that were + sent are segments or the query. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse: diff --git a/ai21/clients/common/gec_base.py b/ai21/clients/common/gec_base.py index 8de743e2..091e6427 100644 --- a/ai21/clients/common/gec_base.py +++ b/ai21/clients/common/gec_base.py @@ -9,6 +9,12 @@ class GEC(ABC): @abstractmethod def create(self, text: str, **kwargs) -> GECResponse: + """ + + :param text: The input text to be corrected. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> GECResponse: diff --git a/ai21/clients/common/improvements_base.py b/ai21/clients/common/improvements_base.py index df912e1d..df13fe58 100644 --- a/ai21/clients/common/improvements_base.py +++ b/ai21/clients/common/improvements_base.py @@ -10,6 +10,13 @@ class Improvements(ABC): @abstractmethod def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: + """ + + :param text: The input text to be improved. + :param types: Types of improvements to apply. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse: diff --git a/ai21/clients/common/paraphrase_base.py b/ai21/clients/common/paraphrase_base.py index 917cdd75..3c01bbb7 100644 --- a/ai21/clients/common/paraphrase_base.py +++ b/ai21/clients/common/paraphrase_base.py @@ -18,6 +18,16 @@ def create( end_index: Optional[int] = None, **kwargs, ) -> ParaphraseResponse: + """ + + :param text: The input text to be paraphrased. + :param style: Controls length and tone + :param start_index: Specifies the starting position of the paraphrasing process in the given text + :param end_index: specifies the position of the last character to be paraphrased, including the character + following it. If the parameter is not provided, the default value is set to the length of the given text. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse: diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py index c4f658a9..97c74104 100644 --- a/ai21/clients/common/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -10,6 +10,13 @@ class Segmentation(ABC): @abstractmethod def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: + """ + + :param source: Raw input text, or URL of a web page. + :param source_type: The type of the source - either TEXT or URL. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py index c74d4538..2a70dfcd 100644 --- a/ai21/clients/common/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -19,7 +19,7 @@ def create( """ :param source: The input text, or URL of a web page to be summarized. :param source_type: Either TEXT or URL - :param focus: Return only summaries focused on a topic of your choice. + :param focus: Summaries focused on a topic of your choice. :param summary_method: :param kwargs: :return: diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py index 40a4abfa..236337de 100644 --- a/ai21/clients/common/summarize_by_segment_base.py +++ b/ai21/clients/common/summarize_by_segment_base.py @@ -19,6 +19,14 @@ def create( focus: Optional[str] = None, **kwargs, ) -> SummarizeBySegmentResponse: + """ + + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Summaries focused on a topic of your choice. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse: From 42505af8fcd0063609868ab217087fa9d7516335 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 30 Jan 2024 10:33:36 +0200 Subject: [PATCH 25/27] fix: question --- examples/studio/library_answer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py index 6ec3f092..20d46402 100644 --- a/examples/studio/library_answer.py +++ b/examples/studio/library_answer.py @@ -2,5 +2,5 @@ client = AI21Client() -response = client.library.answer.create(question="Tell me something about Holland") +response = client.library.answer.create(question="Can you tell me something about Holland?") print(response) From 224d8ce680a21c69886b91a391955ba5980f6965 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 30 Jan 2024 10:41:46 +0200 Subject: [PATCH 26/27] fix: CR --- .../clients/studio/test_completion.py | 41 ++++--------------- .../clients/studio/test_gec.py | 6 --- 2 files changed, 8 insertions(+), 39 deletions(-) diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py index 440b5b89..9938fa2a 100644 --- a/tests/integration_tests/clients/studio/test_completion.py +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -3,36 +3,11 @@ from ai21 import AI21Client from ai21.models import Penalty -_PROMPT = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) +_PROMPT = """ +User: Haven't received a confirmation email for my order #12345. +Assistant: I'm sorry to hear that. I'll look into it right away. +User: Can you please let me know when I can expect to receive it? +""" def test_completion(): @@ -43,9 +18,9 @@ def test_completion(): prompt=_PROMPT, max_tokens=64, model="j2-ultra", - temperature=1, + temperature=0.7, top_p=0.2, - top_k_return=0, + top_k_return=0.2, stop_sequences=["##"], num_results=num_results, custom_model=None, @@ -79,7 +54,7 @@ def test_completion(): assert response.prompt.text == _PROMPT assert len(response.completions) == num_results # Check the results aren't all the same - assert len(set([completion.data.text for completion in response.completions])) == num_results + assert len([completion.data.text for completion in response.completions]) == num_results for completion in response.completions: assert isinstance(completion.data.text, str) diff --git a/tests/integration_tests/clients/studio/test_gec.py b/tests/integration_tests/clients/studio/test_gec.py index 1503c15c..51418cf1 100644 --- a/tests/integration_tests/clients/studio/test_gec.py +++ b/tests/integration_tests/clients/studio/test_gec.py @@ -10,7 +10,6 @@ "should_fix_missing_word", "should_fix_punctuation", "should_fix_wrong_word", - # "should_fix_word_repetition", ], argnames=["text", "correction_type", "expected_suggestion"], argvalues=[ @@ -23,11 +22,6 @@ ), ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), - # ( - # "The mountain was tall, and the tall mountain could be seen from miles away.", - # CorrectionType.WORD_REPETITION, - # "like", - # ), ], ) def test_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): From 93caebda36c9c9b988e02261e524d71eded46ded Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 30 Jan 2024 10:49:04 +0200 Subject: [PATCH 27/27] test: Added tests to segment type in embed --- .../clients/studio/test_embed.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/integration_tests/clients/studio/test_embed.py b/tests/integration_tests/clients/studio/test_embed.py index ea260d76..8fc77dc5 100644 --- a/tests/integration_tests/clients/studio/test_embed.py +++ b/tests/integration_tests/clients/studio/test_embed.py @@ -7,23 +7,31 @@ _TEXT_0 = "Holland is a geographical region and former province on the western coast of the Netherlands." _TEXT_1 = "Germany is a country in Central Europe. It is the second-most populous country in Europe after Russia" +_SEGMENT_0 = "The sun sets behind the mountains," +_SEGMENT_1 = "casting a warm glow over" +_SEGMENT_2 = "the city of Amsterdam." + @pytest.mark.parametrize( ids=[ - "when_single_text__should_return_single_embedding", - "when_multiple_text__should_return_multiple_embeddings", + "when_single_text_and_query__should_return_single_embedding", + "when_multiple_text_and_query__should_return_multiple_embeddings", + "when_single_text_and_segment__should_return_single_embedding", + "when_multiple_text_and_segment__should_return_multiple_embeddings", ], - argnames=["texts"], + argnames=["texts", "type"], argvalues=[ - ([_TEXT_0],), - ([_TEXT_0, _TEXT_1],), + ([_TEXT_0], EmbedType.QUERY), + ([_TEXT_0, _TEXT_1], EmbedType.QUERY), + ([_SEGMENT_0], EmbedType.SEGMENT), + ([_SEGMENT_0, _SEGMENT_1, _SEGMENT_2], EmbedType.SEGMENT), ], ) -def test_embed(texts: List[str]): +def test_embed(texts: List[str], type: EmbedType): client = AI21Client() response = client.embed.create( texts=texts, - type=EmbedType.QUERY, + type=type, ) assert len(response.results) == len(texts)