Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6efbca7
feat: :sparkles: introduce Maestro
Mar 11, 2025
f72286e
ci: Added remove cache
Josephasafg Mar 11, 2025
a267dbc
chore: :refactor: use `ChatMessage` instead of `Message` pretty Typed…
Mar 11, 2025
548d81e
chore: :refactor: use `ChatMessage` instead of `Message` pretty Typed…
Mar 11, 2025
e76bba6
chore: :truck: rename `maestro/runs` files to `maestro/run`
Mar 11, 2025
c885a72
fix: :recycle: bug fixes and support for more params
Mar 11, 2025
ca3424a
fix: :fire: remove unsupported parameters
Mar 12, 2025
74bc78e
fix: :fire: remove unsupported parameters
Mar 12, 2025
a1c2f0a
fix: :bug: try saving the day
Mar 12, 2025
8879087
fix: :bug: try saving the day
Mar 12, 2025
50b09aa
fix: :bug: try saving the day
Mar 12, 2025
8ece7db
Merge branch 'main' into EXEC-866-maestro
Mar 12, 2025
87953d1
fix: :bug: let's give it another go shall we
Mar 12, 2025
b0104f0
fix: :fire: remove unused functions
Mar 13, 2025
12151a6
test: :white_check_mark: tests
Mar 13, 2025
495ce5b
refactor: :truck: rename messages and constraints
Mar 16, 2025
02efd99
ci: :technologist: add git hooks to check for in commit content
Mar 16, 2025
2e7f9f3
chore: :truck: move `maestro` under beta
Mar 16, 2025
ce25f4d
Merge branch 'main' into EXEC-866-maestro
benshuk Mar 16, 2025
df84031
docs: :memo: update examples
Mar 18, 2025
c124fa0
docs: :memo: update README
Mar 19, 2025
ccacd97
refactor: :truck: rename maestro runs examples
Mar 19, 2025
f3604fd
chore: :wrench: add budget support
Mar 19, 2025
4d8c5ea
Merge branch 'main' into EXEC-866-maestro
benshuk Mar 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .git-hooks/check_api_key.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash

# Check for `api_key=` in staged changes
if git diff --cached | grep -q "api_key="; then
echo "❌ Commit blocked: Found 'api_key=' in staged changes."
exit 1 # Prevent commit
fi

exit 0 # Allow commit
1 change: 1 addition & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ jobs:
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
poetry cache clear --all pypi
- name: Install dependencies
run: |
poetry install --all-extras
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
- name: Set Poetry environment and clear cache
run: |
poetry env use ${{ matrix.python-version }}
poetry cache clear --all pypi
- name: Install dependencies
run: |
poetry install --no-root --only dev --all-extras
Expand Down Expand Up @@ -58,9 +59,10 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
- name: Set Poetry environment and clear cache
run: |
poetry env use ${{ matrix.python-version }}
poetry cache clear --all pypi
- name: Override Pydantic version
run: |
if [[ "${{ matrix.pydantic-version }}" == ^1.* ]]; then
Expand Down
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,10 @@ repos:
entry: hadolint/hadolint:v2.10.0 hadolint
types:
- dockerfile
- repo: local
hooks:
- id: check-api-key
name: Check for API keys
entry: .git-hooks/check_api_key.sh
language: system
stages: [commit]
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,34 @@ asyncio.run(main())

---

### Maestro

AI Planning & Orchestration System built for the enterprise. Read more [here](https://www.ai21.com/maestro/).

```python
from ai21 import AI21Client

client = AI21Client()

run_result = client.beta.maestro.runs.create_and_poll(
input="Write a poem about the ocean",
requirements=[
{
"name": "length requirement",
"description": "The length of the poem should be less than 1000 characters",
},
{
"name": "rhyme requirement",
"description": "The poem should rhyme",
},
],
)
```

For a more detailed example, see maestro [sync](examples/studio/maestro/run.py) and [async](examples/studio/maestro/async_run.py) examples.

---

### Conversational RAG (Beta)

Like chat, but with the ability to retrieve information from your Studio library.
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions ai21/clients/common/maestro/maestro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from abc import ABC

from ai21.clients.common.maestro.run import BaseMaestroRun


class BaseMaestro(ABC):
_module_name = "maestro"
runs: BaseMaestroRun
86 changes: 86 additions & 0 deletions ai21/clients/common/maestro/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Dict, Any

from ai21.models.chat import ChatMessage
from ai21.models.maestro.run import (
Tool,
ToolResources,
RunResponse,
DEFAULT_RUN_POLL_INTERVAL,
DEFAULT_RUN_POLL_TIMEOUT,
Requirement,
Budget,
)
from ai21.types import NOT_GIVEN, NotGiven
from ai21.utils.typing import remove_not_given


class BaseMaestroRun(ABC):
_module_name = "maestro/runs"

def _create_body(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven,
tools: List[Tool] | NotGiven,
tool_resources: ToolResources | NotGiven,
context: Dict[str, Any] | NotGiven,
requirements: List[Requirement] | NotGiven,
budget: Budget | NotGiven,
**kwargs,
) -> dict:
return remove_not_given(
{
"input": input,
"models": models,
"tools": tools,
"tool_resources": tool_resources,
"context": context,
"requirements": requirements,
"budget": budget,
**kwargs,
}
)

@abstractmethod
def create(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
pass

@abstractmethod
def retrieve(self, run_id: str) -> RunResponse:
pass

@abstractmethod
def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse:
pass

@abstractmethod
def create_and_poll(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
) -> RunResponse:
pass
2 changes: 2 additions & 0 deletions ai21/clients/studio/resources/beta/async_beta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ai21.clients.studio.resources.maestro.maestro import AsyncMaestro
from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag
from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
Expand All @@ -8,3 +9,4 @@ def __init__(self, client: AsyncAI21HTTPClient):
super().__init__(client)

self.conversational_rag = AsyncStudioConversationalRag(client)
self.maestro = AsyncMaestro(client)
2 changes: 2 additions & 0 deletions ai21/clients/studio/resources/beta/beta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ai21.clients.studio.resources.maestro.maestro import Maestro
from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.http_client.http_client import AI21HTTPClient
Expand All @@ -8,3 +9,4 @@ def __init__(self, client: AI21HTTPClient):
super().__init__(client)

self.conversational_rag = StudioConversationalRag(client)
self.maestro = Maestro(client)
Empty file.
19 changes: 19 additions & 0 deletions ai21/clients/studio/resources/maestro/maestro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from ai21.clients.common.maestro.maestro import BaseMaestro
from ai21.clients.studio.resources.maestro.run import MaestroRun, AsyncMaestroRun
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient


class Maestro(StudioResource, BaseMaestro):
def __init__(self, client: AI21HTTPClient):
super().__init__(client)

self.runs = MaestroRun(client)


class AsyncMaestro(AsyncStudioResource, BaseMaestro):
def __init__(self, client: AsyncAI21HTTPClient):
super().__init__(client)

self.runs = AsyncMaestroRun(client)
170 changes: 170 additions & 0 deletions ai21/clients/studio/resources/maestro/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import annotations

import asyncio
import time
from typing import Any, List, Dict

from ai21.clients.common.maestro.run import BaseMaestroRun
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models.chat import ChatMessage
from ai21.models.maestro.run import (
Tool,
ToolResources,
RunResponse,
TERMINATED_RUN_STATUSES,
DEFAULT_RUN_POLL_INTERVAL,
DEFAULT_RUN_POLL_TIMEOUT,
Requirement,
Budget,
)
from ai21.types import NotGiven, NOT_GIVEN


class MaestroRun(StudioResource, BaseMaestroRun):
def create(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
body = self._create_body(
input=input,
models=models,
tools=tools,
tool_resources=tool_resources,
context=context,
requirements=requirements,
budget=budget,
**kwargs,
)

return self._post(path=f"/{self._module_name}", body=body, response_cls=RunResponse)

def retrieve(
self,
run_id: str,
) -> RunResponse:
return self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse)

def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse:
start_time = time.time()

while True:
run = self.retrieve(run_id)

if run.status in TERMINATED_RUN_STATUSES:
return run

if (time.time() - start_time) >= poll_timeout:
return run

time.sleep(poll_interval)

def create_and_poll(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
) -> RunResponse:
run = self.create(
input=input,
models=models,
tools=tools,
tool_resources=tool_resources,
context=context,
requirements=requirements,
budget=budget,
**kwargs,
)

return self._poll_for_status(run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec)


class AsyncMaestroRun(AsyncStudioResource, BaseMaestroRun):
async def create(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
body = self._create_body(
input=input,
models=models,
tools=tools,
tool_resources=tool_resources,
context=context,
requirements=requirements,
budget=budget,
**kwargs,
)

return await self._post(path=f"/{self._module_name}", body=body, response_cls=RunResponse)

async def retrieve(
self,
run_id: str,
) -> RunResponse:
return await self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse)

async def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse:
start_time = time.time()

while True:
run = await self.retrieve(run_id)

if run.status in TERMINATED_RUN_STATUSES:
return run

if (time.time() - start_time) >= poll_timeout:
return run

await asyncio.sleep(poll_interval)

async def create_and_poll(
self,
*,
input: str | List[ChatMessage],
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
budget: Budget | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
) -> RunResponse:
run = await self.create(
input=input,
models=models,
tools=tools,
tool_resources=tool_resources,
context=context,
requirements=requirements,
budget=budget,
**kwargs,
)

return await self._poll_for_status(
run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
)
Loading
Loading