Skip to content
Merged
37 changes: 33 additions & 4 deletions ai21/clients/common/beta/assistant/plans.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,62 @@
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Type, Callable, List

from pydantic import BaseModel

from ai21.errors import CodeParsingError
from ai21.models._pydantic_compatibility import _to_schema
from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
from ai21.types import NOT_GIVEN, NotGiven
from ai21.utils.typing import remove_not_given


class BasePlans(ABC):
_module_name = "plans"

def _parse_schema(self, schema: Type[BaseModel] | Dict[str, Any]) -> Dict[str, Any]:
if inspect.isclass(schema) and issubclass(schema, BaseModel):
return _to_schema(schema)
return schema

def _parse_code(self, code: str | Callable) -> str:
if callable(code):
try:
return inspect.getsource(code).strip()
except OSError as e:
raise CodeParsingError(str(e))
except Exception:
raise CodeParsingError()
return code

@abstractmethod
def create(
self,
*,
assistant_id: str,
code: str,
code: str | Callable,
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> PlanResponse:
pass

def _create_body(
self,
*,
code: str,
code: str | Callable,
schemas: List[Dict[str, Any]] | List[BaseModel] | NotGiven = NOT_GIVEN,
**kwargs,
) -> Dict[str, Any]:
code_str = self._parse_code(code)

return remove_not_given(
{
"code": code,
"code": code_str,
"schemas": (
[self._parse_schema(schema) for schema in schemas] if schemas is not NOT_GIVEN else NOT_GIVEN
),
**kwargs,
}
)
Expand Down Expand Up @@ -57,5 +85,6 @@ def modify(
assistant_id: str,
plan_id: str,
code: str,
schemas: List[Dict[str, Any]] | NotGiven = NOT_GIVEN,
) -> PlanResponse:
pass
21 changes: 19 additions & 2 deletions ai21/clients/studio/resources/beta/assistant/assistants_plans.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

from typing import List, Any, Dict, Type

from pydantic import BaseModel

from ai21.clients.common.beta.assistant.plans import BasePlans
from ai21.clients.studio.resources.studio_resource import (
AsyncStudioResource,
StudioResource,
)
from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
from ai21.types import NotGiven, NOT_GIVEN


class AssistantPlans(StudioResource, BasePlans):
Expand All @@ -14,10 +19,12 @@ def create(
*,
assistant_id: str,
code: str,
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> PlanResponse:
body = self._create_body(
code=code,
schemas=schemas,
**kwargs,
)

Expand All @@ -44,8 +51,12 @@ def modify(
assistant_id: str,
plan_id: str,
code: str,
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
) -> PlanResponse:
body = dict(code=code)
body = self._create_body(
code=code,
schemas=schemas,
)

return self._patch(
path=f"/assistants/{assistant_id}/{self._module_name}/{plan_id}", body=body, response_cls=PlanResponse
Expand All @@ -58,10 +69,12 @@ async def create(
*,
assistant_id: str,
code: str,
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> PlanResponse:
body = self._create_body(
code=code,
schemas=schemas,
**kwargs,
)

Expand Down Expand Up @@ -90,8 +103,12 @@ async def modify(
assistant_id: str,
plan_id: str,
code: str,
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
) -> PlanResponse:
body = dict(code=code)
body = self._create_body(
code=code,
schemas=schemas,
)

return await self._patch(
path=f"/assistants/{assistant_id}/{self._module_name}/{plan_id}", body=body, response_cls=PlanResponse
Expand Down
7 changes: 7 additions & 0 deletions ai21/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,10 @@ def __init__(self, chunk: str, error_message: Optional[str] = None):
class InternalDependencyException(AI21APIError):
def __init__(self, details: Optional[str] = None):
super().__init__(530, details)


class CodeParsingError(AI21Error):
def __init__(self, details: Optional[str] = None):
message = f"Code can't be parsed{'' if details is None else f': {details}'}"
super().__init__(message)
self.message = message
9 changes: 8 additions & 1 deletion ai21/models/_pydantic_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Dict, Any
from typing import Dict, Any, Type

from pydantic import VERSION, BaseModel

Expand Down Expand Up @@ -33,3 +33,10 @@ def _from_json(obj: "AI21BaseModel", json_str: str, **kwargs) -> BaseModel: # n
return obj.model_validate_json(json_str, **kwargs)

return obj.parse_raw(json_str, **kwargs)


def _to_schema(model_object: Type[BaseModel], **kwargs) -> Dict[str, Any]:
if IS_PYDANTIC_V2:
return model_object.model_json_schema(**kwargs)

return model_object.schema(**kwargs)
3 changes: 2 additions & 1 deletion ai21/models/responses/plan_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Dict, Any

from ai21.models.ai21_base_model import AI21BaseModel

Expand All @@ -10,6 +10,7 @@ class PlanResponse(AI21BaseModel):
updated_at: datetime
assistant_id: str
code: str
schemas: List[Dict[str, Any]]


class ListPlanResponse(AI21BaseModel):
Expand Down
25 changes: 25 additions & 0 deletions examples/studio/assistant/user_defined_plans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from ai21 import AI21Client
from pydantic import BaseModel

TIMEOUT = 20


def test_func():
pass


class ExampleSchema(BaseModel):
name: str
id: str


def main():
ai21_client = AI21Client()

assistant = ai21_client.beta.assistants.create(name="My Assistant")

plan = ai21_client.beta.assistants.plans.create(assistant_id=assistant.id, code=test_func, schemas=[ExampleSchema])
route = ai21_client.beta.assistants.routes.create(
assistant_id=assistant.id, plan_id=plan.id, name="My Route", examples=["hi"], description="My Route Description"
)
print(f"Route: {route}")
2 changes: 2 additions & 0 deletions tests/integration_tests/clients/test_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
("chat/chat_function_calling.py",),
("chat/chat_function_calling_multiple_tools.py",),
("chat/chat_response_format.py",),
("assistant/user_defined_plans.py",),
],
ids=[
"when_tokenization__should_return_ok",
Expand All @@ -35,6 +36,7 @@
"when_chat_completions_with_function_calling__should_return_ok",
"when_chat_completions_with_function_calling_multiple_tools_should_return_ok",
"when_chat_completions_with_response_format__should_return_ok",
"when_assistant_with_user_defined_plans_should_return_ok",
],
)
def test_studio(test_file_name: str):
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Callable, List, Dict, Any, Type, Union

from pydantic import BaseModel
from ai21.clients.common.beta.assistant.plans import BasePlans
from ai21.errors import CodeParsingError
from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
from ai21.types import NotGiven, NOT_GIVEN
import pytest


class PlanTestClass(BasePlans):
def create(
self,
*,
assistant_id: str,
code: Union[str, Callable],
schemas: Union[List[Dict[str, Any]], List[Type[BaseModel]], NotGiven] = NOT_GIVEN,
**kwargs,
) -> PlanResponse:
pass

def list(self, *, assistant_id: str) -> ListPlanResponse:
pass

def retrieve(self, *, assistant_id: str, plan_id: str) -> PlanResponse:
pass

def modify(
self, *, assistant_id: str, plan_id: str, code: str, schemas: Union[List[Dict[str, Any]], NotGiven] = NOT_GIVEN
) -> PlanResponse:
pass


def test_create_body__when_pass_code_str__should_return_dict():
# Arrange
code = "code"

# Act
result = PlanTestClass()._create_body(code=code)

# Assert
assert result == {"code": code}


def test_create_body__when_pass_code_callable__should_return_dict():
# Arrange
def code():
return "code"

# Act
result = PlanTestClass()._create_body(code=code)

# Assert
assert result == {"code": 'def code():\n return "code"'}


def test_create_body__when_pass_code_and_dict_schemas__should_return_dict_with_schemas():
# Arrange
code = "code"
schemas = [{"type": "object", "properties": {"name": {"type": "string"}}}]

# Act
result = PlanTestClass()._create_body(code=code, schemas=schemas)

# Assert
assert result == {"code": code, "schemas": schemas}


class TestSchema(BaseModel):
name: str
age: int


def test_create_body__when_pass_code_and_pydantic_schemas__should_return_dict_with_converted_schemas():
# Arrange
code = "code"
schemas = [TestSchema]

# Act
result = PlanTestClass()._create_body(code=code, schemas=schemas)

# Assert
expected_schema = {
"properties": {"age": {"title": "Age", "type": "integer"}, "name": {"title": "Name", "type": "string"}},
"required": ["name", "age"],
"title": "TestSchema",
"type": "object",
}
assert result == {"code": code, "schemas": [expected_schema]}


def test_create_body__when_pass_code_and_not_given_schemas__should_return_dict_without_schemas():
# Arrange
code = "code"

# Act
result = PlanTestClass()._create_body(code=code, schemas=NOT_GIVEN)

# Assert
assert result == {"code": code}


def test_create_body__when_pass_empty_schemas_list__should_return_dict_with_empty_schemas():
# Arrange
code = "code"
schemas = []

# Act
result = PlanTestClass()._create_body(code=code, schemas=schemas)

# Assert
assert result == {"code": code, "schemas": schemas}


def test_create_body__when_cannot_get_source_code__should_raise_code_parsing_error():
# Arrange
class CallableWithoutSource:
def __call__(self):
return "result"

# Override __code__ to simulate a built-in function or method
@property
def __code__(self):
raise AttributeError("'CallableWithoutSource' object has no attribute '__code__'")

code = CallableWithoutSource()

# Act & Assert
with pytest.raises(CodeParsingError):
PlanTestClass()._create_body(code=code)
Loading