From 2462fc19d96584f8dbac28758cc95f28311fe829 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 18 Feb 2024 11:21:04 +0200 Subject: [PATCH 1/2] fix: Removed answer_length and mode from answer --- ai21/clients/common/answer_base.py | 13 +++---------- .../clients/sagemaker/resources/sagemaker_answer.py | 9 ++------- ai21/clients/studio/resources/studio_answer.py | 9 ++------- ai21/clients/studio/resources/studio_library.py | 6 +----- ai21/models/__init__.py | 4 ---- ai21/models/answer_length.py | 7 ------- ai21/models/mode.py | 6 ------ .../integration_tests/clients/studio/test_answer.py | 3 --- .../studio/resources/test_studio_resources.py | 2 -- 9 files changed, 8 insertions(+), 51 deletions(-) delete mode 100644 ai21/models/answer_length.py delete mode 100644 ai21/models/mode.py diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index 43188d8d..4ff3706c 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -1,7 +1,7 @@ from abc import ABC -from typing import Optional, Any, Dict +from typing import Any, Dict -from ai21.models import Mode, AnswerLength, AnswerResponse +from ai21.models import AnswerResponse class Answer(ABC): @@ -11,17 +11,12 @@ def create( self, context: str, question: str, - *, - answer_length: Optional[AnswerLength] = None, - mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: """ :param context: A string containing the document context for which the question will be answered :param question: A string containing the question to be answered based on the provided context. - :param answer_length: Approximate length of the answer in words. - :param mode: :param kwargs: :return: """ @@ -34,7 +29,5 @@ def _create_body( self, context: str, question: str, - answer_length: Optional[AnswerLength], - mode: Optional[str], ) -> Dict[str, Any]: - return {"context": context, "question": question, "answerLength": answer_length, "mode": mode} + return {"context": context, "question": question} diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py index 03760daf..d4a6ceb5 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ b/ai21/clients/sagemaker/resources/sagemaker_answer.py @@ -1,8 +1,6 @@ -from typing import Optional - from ai21.clients.common.answer_base import Answer from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource -from ai21.models import AnswerResponse, AnswerLength, Mode +from ai21.models import AnswerResponse class SageMakerAnswer(SageMakerResource, Answer): @@ -10,12 +8,9 @@ def create( self, context: str, question: str, - *, - answer_length: Optional[AnswerLength] = None, - mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: - body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) + body = self._create_body(context=context, question=question) response = self._invoke(body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 6fe86c5e..50b45cb7 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -1,8 +1,6 @@ -from typing import Optional - from ai21.clients.common.answer_base import Answer from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import AnswerLength, Mode, AnswerResponse +from ai21.models import AnswerResponse class StudioAnswer(StudioResource, Answer): @@ -10,14 +8,11 @@ def create( self, context: str, question: str, - *, - answer_length: Optional[AnswerLength] = None, - mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: url = f"{self._client.get_base_url()}/{self._module_name}" - body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) + body = self._create_body(context=context, question=question) response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 782fad51..c028cfcb 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -2,7 +2,7 @@ from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import Mode, AnswerLength, FileResponse, LibraryAnswerResponse, LibrarySearchResponse +from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse class StudioLibrary(StudioResource): @@ -107,8 +107,6 @@ def create( path: Optional[str] = None, field_ids: Optional[List[str]] = None, labels: Optional[List[str]] = None, - answer_length: Optional[AnswerLength] = None, - mode: Optional[Mode] = None, **kwargs, ) -> LibraryAnswerResponse: url = f"{self._client.get_base_url()}/{self._module_name}" @@ -117,8 +115,6 @@ def create( "path": path, "fieldIds": field_ids, "labels": labels, - "answerLength": answer_length, - "mode": mode, } raw_response = self._post(url=url, body=body) return LibraryAnswerResponse.from_dict(raw_response) diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index 5e5877d8..174215f5 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -1,9 +1,7 @@ -from ai21.models.answer_length import AnswerLength from ai21.models.chat_message import ChatMessage from ai21.models.document_type import DocumentType from ai21.models.embed_type import EmbedType from ai21.models.improvement_type import ImprovementType -from ai21.models.mode import Mode from ai21.models.paraphrase_style_type import ParaphraseStyleType from ai21.models.penalty import Penalty from ai21.models.responses.answer_response import AnswerResponse @@ -32,8 +30,6 @@ __all__ = [ - "AnswerLength", - "Mode", "ChatMessage", "RoleType", "Penalty", diff --git a/ai21/models/answer_length.py b/ai21/models/answer_length.py deleted file mode 100644 index 8adc82fd..00000000 --- a/ai21/models/answer_length.py +++ /dev/null @@ -1,7 +0,0 @@ -from enum import Enum - - -class AnswerLength(str, Enum): - SHORT = "short" - MEDIUM = "medium" - LONG = "long" diff --git a/ai21/models/mode.py b/ai21/models/mode.py deleted file mode 100644 index e3e49347..00000000 --- a/ai21/models/mode.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class Mode(str, Enum): - FLEXIBLE = "flexible" - STRICT = "strict" diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py index 51dd4fa2..1e83ee49 100644 --- a/tests/integration_tests/clients/studio/test_answer.py +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -1,6 +1,5 @@ import pytest from ai21 import AI21Client -from ai21.models import AnswerLength, Mode _CONTEXT = ( "Holland is a geographical region[2] and former province on the western coast of" @@ -27,8 +26,6 @@ def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: response = client.answer.create( context=_CONTEXT, question=question, - answer_length=AnswerLength.LONG, - mode=Mode.FLEXIBLE, ) assert response.answer_in_context == is_answer_in_context diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index cea0df9d..eac6f274 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -96,9 +96,7 @@ def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_s method="POST", url=_BASE_URL + "/answer", params={ - "answerLength": None, "context": _DUMMY_CONTEXT, - "mode": None, "question": _DUMMY_QUESTION, }, files=None, From 470272dfd8f4016e84c1ec53fb95e709fd27b2fd Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 18 Feb 2024 11:23:07 +0200 Subject: [PATCH 2/2] test: Fixed test --- tests/unittests/clients/studio/resources/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 1849cf42..00e2088d 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -60,9 +60,7 @@ def get_studio_answer(): {"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION}, "answer", { - "answerLength": None, "context": _DUMMY_CONTEXT, - "mode": None, "question": _DUMMY_QUESTION, }, AnswerResponse(id="some-id", answer_in_context=True, answer="42"),