diff --git a/.git-hooks/check_api_key.sh b/.git-hooks/check_api_key.sh new file mode 100755 index 00000000..6bb06596 --- /dev/null +++ b/.git-hooks/check_api_key.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# Check for `api_key=` in staged changes +if git diff --cached | grep -q "api_key="; then + echo "❌ Commit blocked: Found 'api_key=' in staged changes." + exit 1 # Prevent commit +fi + +exit 0 # Allow commit diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 6e9c89de..3d71fb48 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -53,6 +53,7 @@ jobs: - name: Set Poetry environment run: | poetry env use ${{ matrix.python-version }} + poetry cache clear --all pypi - name: Install dependencies run: | poetry install --all-extras diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3e22653f..8c4ffb26 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,9 +25,10 @@ jobs: python-version: ${{ matrix.python-version }} cache: poetry cache-dependency-path: poetry.lock - - name: Set Poetry environment + - name: Set Poetry environment and clear cache run: | poetry env use ${{ matrix.python-version }} + poetry cache clear --all pypi - name: Install dependencies run: | poetry install --no-root --only dev --all-extras @@ -58,9 +59,10 @@ jobs: python-version: ${{ matrix.python-version }} cache: poetry cache-dependency-path: poetry.lock - - name: Set Poetry environment + - name: Set Poetry environment and clear cache run: | poetry env use ${{ matrix.python-version }} + poetry cache clear --all pypi - name: Override Pydantic version run: | if [[ "${{ matrix.pydantic-version }}" == ^1.* ]]; then diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c210e34..eece3b24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -118,3 +118,10 @@ repos: entry: hadolint/hadolint:v2.10.0 hadolint types: - dockerfile + - repo: local + hooks: + - id: check-api-key + name: Check for API keys + entry: .git-hooks/check_api_key.sh + language: system + stages: [commit] diff --git a/README.md b/README.md index 7f92e4bd..4276122b 100644 --- a/README.md +++ b/README.md @@ -248,6 +248,34 @@ asyncio.run(main()) --- +### Maestro + +AI Planning & Orchestration System built for the enterprise. Read more [here](https://www.ai21.com/maestro/). + +```python +from ai21 import AI21Client + +client = AI21Client() + +run_result = client.beta.maestro.runs.create_and_poll( + input="Write a poem about the ocean", + requirements=[ + { + "name": "length requirement", + "description": "The length of the poem should be less than 1000 characters", + }, + { + "name": "rhyme requirement", + "description": "The poem should rhyme", + }, + ], +) +``` + +For a more detailed example, see maestro [sync](examples/studio/maestro/run.py) and [async](examples/studio/maestro/async_run.py) examples. + +--- + ### Conversational RAG (Beta) Like chat, but with the ability to retrieve information from your Studio library. diff --git a/ai21/clients/common/maestro/__init__.py b/ai21/clients/common/maestro/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/clients/common/maestro/maestro.py b/ai21/clients/common/maestro/maestro.py new file mode 100644 index 00000000..393e39c9 --- /dev/null +++ b/ai21/clients/common/maestro/maestro.py @@ -0,0 +1,8 @@ +from abc import ABC + +from ai21.clients.common.maestro.run import BaseMaestroRun + + +class BaseMaestro(ABC): + _module_name = "maestro" + runs: BaseMaestroRun diff --git a/ai21/clients/common/maestro/run.py b/ai21/clients/common/maestro/run.py new file mode 100644 index 00000000..43851dec --- /dev/null +++ b/ai21/clients/common/maestro/run.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List, Dict, Any + +from ai21.models.chat import ChatMessage +from ai21.models.maestro.run import ( + Tool, + ToolResources, + RunResponse, + DEFAULT_RUN_POLL_INTERVAL, + DEFAULT_RUN_POLL_TIMEOUT, + Requirement, + Budget, +) +from ai21.types import NOT_GIVEN, NotGiven +from ai21.utils.typing import remove_not_given + + +class BaseMaestroRun(ABC): + _module_name = "maestro/runs" + + def _create_body( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven, + tools: List[Tool] | NotGiven, + tool_resources: ToolResources | NotGiven, + context: Dict[str, Any] | NotGiven, + requirements: List[Requirement] | NotGiven, + budget: Budget | NotGiven, + **kwargs, + ) -> dict: + return remove_not_given( + { + "input": input, + "models": models, + "tools": tools, + "tool_resources": tool_resources, + "context": context, + "requirements": requirements, + "budget": budget, + **kwargs, + } + ) + + @abstractmethod + def create( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven = NOT_GIVEN, + tools: List[Tool] | NotGiven = NOT_GIVEN, + tool_resources: ToolResources | NotGiven = NOT_GIVEN, + context: Dict[str, Any] | NotGiven = NOT_GIVEN, + requirements: List[Requirement] | NotGiven = NOT_GIVEN, + budget: Budget | NotGiven = NOT_GIVEN, + **kwargs, + ) -> RunResponse: + pass + + @abstractmethod + def retrieve(self, run_id: str) -> RunResponse: + pass + + @abstractmethod + def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: + pass + + @abstractmethod + def create_and_poll( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven = NOT_GIVEN, + tools: List[Tool] | NotGiven = NOT_GIVEN, + tool_resources: ToolResources | NotGiven = NOT_GIVEN, + context: Dict[str, Any] | NotGiven = NOT_GIVEN, + requirements: List[Requirement] | NotGiven = NOT_GIVEN, + budget: Budget | NotGiven = NOT_GIVEN, + poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT, + **kwargs, + ) -> RunResponse: + pass diff --git a/ai21/clients/studio/resources/beta/async_beta.py b/ai21/clients/studio/resources/beta/async_beta.py index a94ddc95..8c5bb35a 100644 --- a/ai21/clients/studio/resources/beta/async_beta.py +++ b/ai21/clients/studio/resources/beta/async_beta.py @@ -1,3 +1,4 @@ +from ai21.clients.studio.resources.maestro.maestro import AsyncMaestro from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -8,3 +9,4 @@ def __init__(self, client: AsyncAI21HTTPClient): super().__init__(client) self.conversational_rag = AsyncStudioConversationalRag(client) + self.maestro = AsyncMaestro(client) diff --git a/ai21/clients/studio/resources/beta/beta.py b/ai21/clients/studio/resources/beta/beta.py index 1269f970..c33d3bc0 100644 --- a/ai21/clients/studio/resources/beta/beta.py +++ b/ai21/clients/studio/resources/beta/beta.py @@ -1,3 +1,4 @@ +from ai21.clients.studio.resources.maestro.maestro import Maestro from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.http_client.http_client import AI21HTTPClient @@ -8,3 +9,4 @@ def __init__(self, client: AI21HTTPClient): super().__init__(client) self.conversational_rag = StudioConversationalRag(client) + self.maestro = Maestro(client) diff --git a/ai21/clients/studio/resources/maestro/__init__.py b/ai21/clients/studio/resources/maestro/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/clients/studio/resources/maestro/maestro.py b/ai21/clients/studio/resources/maestro/maestro.py new file mode 100644 index 00000000..346b0cc0 --- /dev/null +++ b/ai21/clients/studio/resources/maestro/maestro.py @@ -0,0 +1,19 @@ +from ai21.clients.common.maestro.maestro import BaseMaestro +from ai21.clients.studio.resources.maestro.run import MaestroRun, AsyncMaestroRun +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.http_client.async_http_client import AsyncAI21HTTPClient +from ai21.http_client.http_client import AI21HTTPClient + + +class Maestro(StudioResource, BaseMaestro): + def __init__(self, client: AI21HTTPClient): + super().__init__(client) + + self.runs = MaestroRun(client) + + +class AsyncMaestro(AsyncStudioResource, BaseMaestro): + def __init__(self, client: AsyncAI21HTTPClient): + super().__init__(client) + + self.runs = AsyncMaestroRun(client) diff --git a/ai21/clients/studio/resources/maestro/run.py b/ai21/clients/studio/resources/maestro/run.py new file mode 100644 index 00000000..5166b714 --- /dev/null +++ b/ai21/clients/studio/resources/maestro/run.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, List, Dict + +from ai21.clients.common.maestro.run import BaseMaestroRun +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.models.chat import ChatMessage +from ai21.models.maestro.run import ( + Tool, + ToolResources, + RunResponse, + TERMINATED_RUN_STATUSES, + DEFAULT_RUN_POLL_INTERVAL, + DEFAULT_RUN_POLL_TIMEOUT, + Requirement, + Budget, +) +from ai21.types import NotGiven, NOT_GIVEN + + +class MaestroRun(StudioResource, BaseMaestroRun): + def create( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven = NOT_GIVEN, + tools: List[Tool] | NotGiven = NOT_GIVEN, + tool_resources: ToolResources | NotGiven = NOT_GIVEN, + context: Dict[str, Any] | NotGiven = NOT_GIVEN, + requirements: List[Requirement] | NotGiven = NOT_GIVEN, + budget: Budget | NotGiven = NOT_GIVEN, + **kwargs, + ) -> RunResponse: + body = self._create_body( + input=input, + models=models, + tools=tools, + tool_resources=tool_resources, + context=context, + requirements=requirements, + budget=budget, + **kwargs, + ) + + return self._post(path=f"/{self._module_name}", body=body, response_cls=RunResponse) + + def retrieve( + self, + run_id: str, + ) -> RunResponse: + return self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse) + + def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: + start_time = time.time() + + while True: + run = self.retrieve(run_id) + + if run.status in TERMINATED_RUN_STATUSES: + return run + + if (time.time() - start_time) >= poll_timeout: + return run + + time.sleep(poll_interval) + + def create_and_poll( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven = NOT_GIVEN, + tools: List[Tool] | NotGiven = NOT_GIVEN, + tool_resources: ToolResources | NotGiven = NOT_GIVEN, + context: Dict[str, Any] | NotGiven = NOT_GIVEN, + requirements: List[Requirement] | NotGiven = NOT_GIVEN, + budget: Budget | NotGiven = NOT_GIVEN, + poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT, + **kwargs, + ) -> RunResponse: + run = self.create( + input=input, + models=models, + tools=tools, + tool_resources=tool_resources, + context=context, + requirements=requirements, + budget=budget, + **kwargs, + ) + + return self._poll_for_status(run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec) + + +class AsyncMaestroRun(AsyncStudioResource, BaseMaestroRun): + async def create( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven = NOT_GIVEN, + tools: List[Tool] | NotGiven = NOT_GIVEN, + tool_resources: ToolResources | NotGiven = NOT_GIVEN, + context: Dict[str, Any] | NotGiven = NOT_GIVEN, + requirements: List[Requirement] | NotGiven = NOT_GIVEN, + budget: Budget | NotGiven = NOT_GIVEN, + **kwargs, + ) -> RunResponse: + body = self._create_body( + input=input, + models=models, + tools=tools, + tool_resources=tool_resources, + context=context, + requirements=requirements, + budget=budget, + **kwargs, + ) + + return await self._post(path=f"/{self._module_name}", body=body, response_cls=RunResponse) + + async def retrieve( + self, + run_id: str, + ) -> RunResponse: + return await self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse) + + async def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: + start_time = time.time() + + while True: + run = await self.retrieve(run_id) + + if run.status in TERMINATED_RUN_STATUSES: + return run + + if (time.time() - start_time) >= poll_timeout: + return run + + await asyncio.sleep(poll_interval) + + async def create_and_poll( + self, + *, + input: str | List[ChatMessage], + models: List[str] | NotGiven = NOT_GIVEN, + tools: List[Tool] | NotGiven = NOT_GIVEN, + tool_resources: ToolResources | NotGiven = NOT_GIVEN, + context: Dict[str, Any] | NotGiven = NOT_GIVEN, + requirements: List[Requirement] | NotGiven = NOT_GIVEN, + budget: Budget | NotGiven = NOT_GIVEN, + poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT, + **kwargs, + ) -> RunResponse: + run = await self.create( + input=input, + models=models, + tools=tools, + tool_resources=tool_resources, + context=context, + requirements=requirements, + budget=budget, + **kwargs, + ) + + return await self._poll_for_status( + run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec + ) diff --git a/ai21/models/_pydantic_compatibility.py b/ai21/models/_pydantic_compatibility.py index 5f58c0de..41d08509 100644 --- a/ai21/models/_pydantic_compatibility.py +++ b/ai21/models/_pydantic_compatibility.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Any +from typing import Dict, Any, Type from pydantic import VERSION, BaseModel @@ -33,3 +33,10 @@ def _from_json(obj: "AI21BaseModel", json_str: str, **kwargs) -> BaseModel: # n return obj.model_validate_json(json_str, **kwargs) return obj.parse_raw(json_str, **kwargs) + + +def _to_schema(model_object: Type[BaseModel], **kwargs) -> Dict[str, Any]: + if IS_PYDANTIC_V2: + return model_object.model_json_schema(**kwargs) + + return model_object.schema(**kwargs) diff --git a/ai21/models/maestro/__init__.py b/ai21/models/maestro/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/models/maestro/run.py b/ai21/models/maestro/run.py new file mode 100644 index 00000000..110c90f1 --- /dev/null +++ b/ai21/models/maestro/run.py @@ -0,0 +1,51 @@ +from typing import TypedDict, Literal, List, Optional, Any, Set, Dict, Type, Union + +from pydantic import BaseModel + +from ai21.models.ai21_base_model import AI21BaseModel + +Budget = Literal["low", "medium", "high"] +Role = Literal["user", "assistant"] +RunStatus = Literal["completed", "failed", "in_progress", "requires_action"] +ToolType = Literal["file_search", "web_search"] +PrimitiveTypes = Union[Type[str], Type[int], Type[float], Type[bool]] +PrimitiveLists = Type[List[PrimitiveTypes]] +OutputType = Union[Type[BaseModel], PrimitiveTypes, Dict[str, Any]] + +DEFAULT_RUN_POLL_INTERVAL: float = 1 # seconds +DEFAULT_RUN_POLL_TIMEOUT: float = 120 # seconds +TERMINATED_RUN_STATUSES: Set[RunStatus] = {"completed", "failed", "requires_action"} + + +class Tool(TypedDict): + type: ToolType + + +class FileSearchToolResource(TypedDict, total=False): + retrieval_similarity_threshold: Optional[float] + labels: Optional[List[str]] + labels_filter_mode: Optional[Literal["AND", "OR"]] + labels_filter: Optional[dict] + file_ids: Optional[List[str]] + retrieval_strategy: Optional[str] + max_neighbors: Optional[int] + + +class WebSearchToolResource(TypedDict, total=False): + urls: Optional[List[str]] + + +class ToolResources(TypedDict, total=False): + file_search: Optional[FileSearchToolResource] + web_search: Optional[WebSearchToolResource] + + +class Requirement(TypedDict): + name: str + description: str + + +class RunResponse(AI21BaseModel): + id: str + status: RunStatus + result: Any diff --git a/examples/studio/maestro/__init__.py b/examples/studio/maestro/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/maestro/async_run.py b/examples/studio/maestro/async_run.py new file mode 100644 index 00000000..f1f00180 --- /dev/null +++ b/examples/studio/maestro/async_run.py @@ -0,0 +1,27 @@ +import asyncio + +from ai21 import AsyncAI21Client + +client = AsyncAI21Client() + + +async def main(): + run_result = await client.beta.maestro.runs.create_and_poll( + input="Write a poem about the ocean", + requirements=[ + { + "name": "length requirement", + "description": "The length of the poem should be less than 1000 characters", + }, + { + "name": "rhyme requirement", + "description": "The poem should rhyme", + }, + ], + ) + + print(run_result) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/studio/maestro/run.py b/examples/studio/maestro/run.py new file mode 100644 index 00000000..72024f33 --- /dev/null +++ b/examples/studio/maestro/run.py @@ -0,0 +1,25 @@ +from ai21 import AI21Client + +client = AI21Client() + + +def main(): + run_result = client.beta.maestro.runs.create_and_poll( + input="Write a poem about the ocean", + requirements=[ + { + "name": "length requirement", + "description": "The length of the poem should be less than 1000 characters", + }, + { + "name": "rhyme requirement", + "description": "The poem should rhyme", + }, + ], + ) + + print(run_result) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 1e10e10e..a23d6e07 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -56,12 +56,16 @@ def test_studio(test_file_name: str): ("chat/async_stream_chat_completions.py",), ("conversational_rag/conversational_rag.py",), ("conversational_rag/async_conversational_rag.py",), + ("maestro/run.py",), + ("maestro/async_run.py",), ], ids=[ "when_chat_completions__should_return_ok", "when_stream_chat_completions__should_return_ok", "when_conversational_rag__should_return_ok", "when_async_conversational_rag__should_return_ok", + "when_maestro_runs__should_return_ok", + "when_maestro_async_runs__should_return_ok", ], ) async def test_async_studio(test_file_name: str):