Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ai21/clients/common/maestro/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DEFAULT_RUN_POLL_TIMEOUT,
Requirement,
Budget,
OutputOptions,
)
from ai21.types import NOT_GIVEN, NotGiven
from ai21.utils.typing import remove_not_given
Expand All @@ -30,6 +31,7 @@ def _create_body(
context: Dict[str, Any] | NotGiven,
requirements: List[Requirement] | NotGiven,
budget: Budget | NotGiven,
include: List[OutputOptions] | NotGiven,
**kwargs,
) -> dict:
return remove_not_given(
Expand All @@ -41,6 +43,7 @@ def _create_body(
"context": context,
"requirements": requirements,
"budget": budget,
"include": include,
**kwargs,
}
)
Expand All @@ -56,6 +59,7 @@ def create(
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
pass
Expand All @@ -79,6 +83,7 @@ def create_and_poll(
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
Expand Down
9 changes: 9 additions & 0 deletions ai21/clients/studio/resources/maestro/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DEFAULT_RUN_POLL_TIMEOUT,
Requirement,
Budget,
OutputOptions,
)
from ai21.types import NotGiven, NOT_GIVEN

Expand All @@ -31,6 +32,7 @@ def create(
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
body = self._create_body(
Expand All @@ -41,6 +43,7 @@ def create(
context=context,
requirements=requirements,
budget=budget,
include=include,
**kwargs,
)

Expand Down Expand Up @@ -76,6 +79,7 @@ def create_and_poll(
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
Expand All @@ -88,6 +92,7 @@ def create_and_poll(
context=context,
requirements=requirements,
budget=budget,
include=include,
**kwargs,
)

Expand All @@ -105,6 +110,7 @@ async def create(
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
body = self._create_body(
Expand All @@ -115,6 +121,7 @@ async def create(
context=context,
requirements=requirements,
budget=budget,
include=include,
**kwargs,
)

Expand Down Expand Up @@ -150,6 +157,7 @@ async def create_and_poll(
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
Expand All @@ -162,6 +170,7 @@ async def create_and_poll(
context=context,
requirements=requirements,
budget=budget,
include=include,
**kwargs,
)

Expand Down
16 changes: 14 additions & 2 deletions ai21/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@
ConversationalRagSource,
)
from ai21.models.responses.file_response import FileResponse
from ai21.models.maestro.run import Requirement, Budget, Tool, ToolResources

from ai21.models.maestro.run import (
Requirement,
Budget,
Tool,
ToolResources,
DataSources,
FileSearchResult,
WebSearchResult,
OutputOptions,
)

__all__ = [
"ChatMessage",
Expand All @@ -26,4 +34,8 @@
"Budget",
"Tool",
"ToolResources",
"DataSources",
"FileSearchResult",
"WebSearchResult",
"OutputOptions",
]
27 changes: 24 additions & 3 deletions ai21/models/maestro/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TypedDict, Literal, List, Optional, Any, Set, Dict, Type, Union

from typing import Literal, List, Optional, Any, Set, Dict, Type, Union
from typing_extensions import TypedDict
from pydantic import BaseModel

from ai21.models.ai21_base_model import AI21BaseModel
Expand All @@ -8,6 +8,7 @@
Role = Literal["user", "assistant"]
RunStatus = Literal["completed", "failed", "in_progress", "requires_action"]
ToolType = Literal["file_search", "web_search"]
OutputOptions = Literal["data_sources"]
PrimitiveTypes = Union[Type[str], Type[int], Type[float], Type[bool]]
PrimitiveLists = Type[List[PrimitiveTypes]]
OutputType = Union[Type[BaseModel], PrimitiveTypes, Dict[str, Any]]
Expand Down Expand Up @@ -40,12 +41,32 @@ class ToolResources(TypedDict, total=False):
web_search: Optional[WebSearchToolResource]


class Requirement(TypedDict):
class Requirement(TypedDict, total=False):
name: str
description: str


class FileSearchResult(TypedDict, total=False):
text: Optional[str]
file_id: str
file_name: str
score: float
order: int


class WebSearchResult(TypedDict, total=False):
text: str
url: str
score: float


class DataSources(TypedDict, total=False):
file_search: Optional[List[FileSearchResult]]
web_search: Optional[List[WebSearchResult]]


class RunResponse(AI21BaseModel):
id: str
status: RunStatus
result: Any
data_sources: Optional[DataSources] = None
2 changes: 1 addition & 1 deletion tests/integration_tests/clients/studio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _wait_for_file_to_process(client: AI21Client, file_id: str, timeout: float =
return

elapsed_time = time.time() - start_time
time.sleep(0.5)
time.sleep(2)

raise TimeoutError(f"Timeout: {timeout} seconds passed. File processing not completed")

Expand Down
16 changes: 16 additions & 0 deletions tests/integration_tests/clients/studio/test_maestro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from ai21 import AsyncAI21Client


@pytest.mark.asyncio
async def test_maestro__when_upload__should_return_data_sources(): # file_in_library: str):
client = AsyncAI21Client()
result = await client.beta.maestro.runs.create_and_poll(
input="When did Einstein receive a Nobel Prize?", tools=[{"type": "file_search"}], include=["data_sources"]
)
assert result.status == "completed", "Expected 'completed' status"
assert result.result, "Expected a non-empty answer"
assert result.data_sources, "Expected data sources"
assert len(result.data_sources["file_search"]) > 0, "Expected at least one file search data source"
assert result.data_sources.get("web_search") is None, "Expected no web search data sources"
Loading