Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions ai21/clients/common/answer_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
from typing import Optional, Any, Dict
from typing import Any, Dict

from ai21.models import Mode, AnswerLength, AnswerResponse
from ai21.models import AnswerResponse


class Answer(ABC):
Expand All @@ -11,17 +11,12 @@ def create(
self,
context: str,
question: str,
*,
answer_length: Optional[AnswerLength] = None,
mode: Optional[Mode] = None,
**kwargs,
) -> AnswerResponse:
"""

:param context: A string containing the document context for which the question will be answered
:param question: A string containing the question to be answered based on the provided context.
:param answer_length: Approximate length of the answer in words.
:param mode:
:param kwargs:
:return:
"""
Expand All @@ -34,7 +29,5 @@ def _create_body(
self,
context: str,
question: str,
answer_length: Optional[AnswerLength],
mode: Optional[str],
) -> Dict[str, Any]:
return {"context": context, "question": question, "answerLength": answer_length, "mode": mode}
return {"context": context, "question": question}
9 changes: 2 additions & 7 deletions ai21/clients/sagemaker/resources/sagemaker_answer.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from typing import Optional

from ai21.clients.common.answer_base import Answer
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.models import AnswerResponse, AnswerLength, Mode
from ai21.models import AnswerResponse


class SageMakerAnswer(SageMakerResource, Answer):
def create(
self,
context: str,
question: str,
*,
answer_length: Optional[AnswerLength] = None,
mode: Optional[Mode] = None,
**kwargs,
) -> AnswerResponse:
body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode)
body = self._create_body(context=context, question=question)
response = self._invoke(body)

return self._json_to_response(response)
9 changes: 2 additions & 7 deletions ai21/clients/studio/resources/studio_answer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
from typing import Optional

from ai21.clients.common.answer_base import Answer
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models import AnswerLength, Mode, AnswerResponse
from ai21.models import AnswerResponse


class StudioAnswer(StudioResource, Answer):
def create(
self,
context: str,
question: str,
*,
answer_length: Optional[AnswerLength] = None,
mode: Optional[Mode] = None,
**kwargs,
) -> AnswerResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"

body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode)
body = self._create_body(context=context, question=question)

response = self._post(url=url, body=body)

Expand Down
6 changes: 1 addition & 5 deletions ai21/clients/studio/resources/studio_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ai21.ai21_http_client import AI21HTTPClient
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models import Mode, AnswerLength, FileResponse, LibraryAnswerResponse, LibrarySearchResponse
from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse


class StudioLibrary(StudioResource):
Expand Down Expand Up @@ -107,8 +107,6 @@ def create(
path: Optional[str] = None,
field_ids: Optional[List[str]] = None,
labels: Optional[List[str]] = None,
answer_length: Optional[AnswerLength] = None,
mode: Optional[Mode] = None,
**kwargs,
) -> LibraryAnswerResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
Expand All @@ -117,8 +115,6 @@ def create(
"path": path,
"fieldIds": field_ids,
"labels": labels,
"answerLength": answer_length,
"mode": mode,
}
raw_response = self._post(url=url, body=body)
return LibraryAnswerResponse.from_dict(raw_response)
4 changes: 0 additions & 4 deletions ai21/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from ai21.models.answer_length import AnswerLength
from ai21.models.chat_message import ChatMessage
from ai21.models.document_type import DocumentType
from ai21.models.embed_type import EmbedType
from ai21.models.improvement_type import ImprovementType
from ai21.models.mode import Mode
from ai21.models.paraphrase_style_type import ParaphraseStyleType
from ai21.models.penalty import Penalty
from ai21.models.responses.answer_response import AnswerResponse
Expand Down Expand Up @@ -32,8 +30,6 @@


__all__ = [
"AnswerLength",
"Mode",
"ChatMessage",
"RoleType",
"Penalty",
Expand Down
7 changes: 0 additions & 7 deletions ai21/models/answer_length.py

This file was deleted.

6 changes: 0 additions & 6 deletions ai21/models/mode.py

This file was deleted.

3 changes: 0 additions & 3 deletions tests/integration_tests/clients/studio/test_answer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from ai21 import AI21Client
from ai21.models import AnswerLength, Mode

_CONTEXT = (
"Holland is a geographical region[2] and former province on the western coast of"
Expand All @@ -27,8 +26,6 @@ def test_answer(question: str, is_answer_in_context: bool, expected_answer_type:
response = client.answer.create(
context=_CONTEXT,
question=question,
answer_length=AnswerLength.LONG,
mode=Mode.FLEXIBLE,
)

assert response.answer_in_context == is_answer_in_context
Expand Down
2 changes: 0 additions & 2 deletions tests/unittests/clients/studio/resources/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def get_studio_answer():
{"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION},
"answer",
{
"answerLength": None,
"context": _DUMMY_CONTEXT,
"mode": None,
"question": _DUMMY_QUESTION,
},
AnswerResponse(id="some-id", answer_in_context=True, answer="42"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_s
method="POST",
url=_BASE_URL + "/answer",
params={
"answerLength": None,
"context": _DUMMY_CONTEXT,
"mode": None,
"question": _DUMMY_QUESTION,
},
files=None,
Expand Down