diff --git a/ai21/clients/common/beta/assistant/plans.py b/ai21/clients/common/beta/assistant/plans.py index e4248f68..6923cf75 100644 --- a/ai21/clients/common/beta/assistant/plans.py +++ b/ai21/clients/common/beta/assistant/plans.py @@ -1,21 +1,43 @@ from __future__ import annotations +import inspect from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Type, Callable, List +from pydantic import BaseModel + +from ai21.errors import CodeParsingError +from ai21.models._pydantic_compatibility import _to_schema from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse +from ai21.types import NOT_GIVEN, NotGiven from ai21.utils.typing import remove_not_given class BasePlans(ABC): _module_name = "plans" + def _parse_schema(self, schema: Type[BaseModel] | Dict[str, Any]) -> Dict[str, Any]: + if inspect.isclass(schema) and issubclass(schema, BaseModel): + return _to_schema(schema) + return schema + + def _parse_code(self, code: str | Callable) -> str: + if callable(code): + try: + return inspect.getsource(code).strip() + except OSError as e: + raise CodeParsingError(str(e)) + except Exception: + raise CodeParsingError() + return code + @abstractmethod def create( self, *, assistant_id: str, - code: str, + code: str | Callable, + schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN, **kwargs, ) -> PlanResponse: pass @@ -23,12 +45,18 @@ def create( def _create_body( self, *, - code: str, + code: str | Callable, + schemas: List[Dict[str, Any]] | List[BaseModel] | NotGiven = NOT_GIVEN, **kwargs, ) -> Dict[str, Any]: + code_str = self._parse_code(code) + return remove_not_given( { - "code": code, + "code": code_str, + "schemas": ( + [self._parse_schema(schema) for schema in schemas] if schemas is not NOT_GIVEN else NOT_GIVEN + ), **kwargs, } ) @@ -57,5 +85,6 @@ def modify( assistant_id: str, plan_id: str, code: str, + schemas: List[Dict[str, Any]] | NotGiven = NOT_GIVEN, ) -> PlanResponse: pass diff --git a/ai21/clients/studio/resources/beta/assistant/assistants_plans.py b/ai21/clients/studio/resources/beta/assistant/assistants_plans.py index aa6a66d4..81d78eee 100644 --- a/ai21/clients/studio/resources/beta/assistant/assistants_plans.py +++ b/ai21/clients/studio/resources/beta/assistant/assistants_plans.py @@ -1,11 +1,16 @@ from __future__ import annotations +from typing import List, Any, Dict, Type + +from pydantic import BaseModel + from ai21.clients.common.beta.assistant.plans import BasePlans from ai21.clients.studio.resources.studio_resource import ( AsyncStudioResource, StudioResource, ) from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse +from ai21.types import NotGiven, NOT_GIVEN class AssistantPlans(StudioResource, BasePlans): @@ -14,10 +19,12 @@ def create( *, assistant_id: str, code: str, + schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN, **kwargs, ) -> PlanResponse: body = self._create_body( code=code, + schemas=schemas, **kwargs, ) @@ -44,8 +51,12 @@ def modify( assistant_id: str, plan_id: str, code: str, + schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN, ) -> PlanResponse: - body = dict(code=code) + body = self._create_body( + code=code, + schemas=schemas, + ) return self._patch( path=f"/assistants/{assistant_id}/{self._module_name}/{plan_id}", body=body, response_cls=PlanResponse @@ -58,10 +69,12 @@ async def create( *, assistant_id: str, code: str, + schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN, **kwargs, ) -> PlanResponse: body = self._create_body( code=code, + schemas=schemas, **kwargs, ) @@ -90,8 +103,12 @@ async def modify( assistant_id: str, plan_id: str, code: str, + schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN, ) -> PlanResponse: - body = dict(code=code) + body = self._create_body( + code=code, + schemas=schemas, + ) return await self._patch( path=f"/assistants/{assistant_id}/{self._module_name}/{plan_id}", body=body, response_cls=PlanResponse diff --git a/ai21/errors.py b/ai21/errors.py index 5e557998..7b2f168f 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -111,3 +111,10 @@ def __init__(self, chunk: str, error_message: Optional[str] = None): class InternalDependencyException(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(530, details) + + +class CodeParsingError(AI21Error): + def __init__(self, details: Optional[str] = None): + message = f"Code can't be parsed{'' if details is None else f': {details}'}" + super().__init__(message) + self.message = message diff --git a/ai21/models/_pydantic_compatibility.py b/ai21/models/_pydantic_compatibility.py index 5f58c0de..41d08509 100644 --- a/ai21/models/_pydantic_compatibility.py +++ b/ai21/models/_pydantic_compatibility.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Any +from typing import Dict, Any, Type from pydantic import VERSION, BaseModel @@ -33,3 +33,10 @@ def _from_json(obj: "AI21BaseModel", json_str: str, **kwargs) -> BaseModel: # n return obj.model_validate_json(json_str, **kwargs) return obj.parse_raw(json_str, **kwargs) + + +def _to_schema(model_object: Type[BaseModel], **kwargs) -> Dict[str, Any]: + if IS_PYDANTIC_V2: + return model_object.model_json_schema(**kwargs) + + return model_object.schema(**kwargs) diff --git a/ai21/models/responses/plan_response.py b/ai21/models/responses/plan_response.py index 3f87152d..db0876a3 100644 --- a/ai21/models/responses/plan_response.py +++ b/ai21/models/responses/plan_response.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Dict, Any from ai21.models.ai21_base_model import AI21BaseModel @@ -10,6 +10,7 @@ class PlanResponse(AI21BaseModel): updated_at: datetime assistant_id: str code: str + schemas: List[Dict[str, Any]] class ListPlanResponse(AI21BaseModel): diff --git a/examples/studio/assistant/user_defined_plans.py b/examples/studio/assistant/user_defined_plans.py new file mode 100644 index 00000000..ba8e1f54 --- /dev/null +++ b/examples/studio/assistant/user_defined_plans.py @@ -0,0 +1,25 @@ +from ai21 import AI21Client +from pydantic import BaseModel + +TIMEOUT = 20 + + +def test_func(): + pass + + +class ExampleSchema(BaseModel): + name: str + id: str + + +def main(): + ai21_client = AI21Client() + + assistant = ai21_client.beta.assistants.create(name="My Assistant") + + plan = ai21_client.beta.assistants.plans.create(assistant_id=assistant.id, code=test_func, schemas=[ExampleSchema]) + route = ai21_client.beta.assistants.routes.create( + assistant_id=assistant.id, plan_id=plan.id, name="My Route", examples=["hi"], description="My Route Description" + ) + print(f"Route: {route}") diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 4a581b0e..24c75a7f 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -25,6 +25,7 @@ ("chat/chat_function_calling.py",), ("chat/chat_function_calling_multiple_tools.py",), ("chat/chat_response_format.py",), + ("assistant/user_defined_plans.py",), ], ids=[ "when_tokenization__should_return_ok", @@ -35,6 +36,7 @@ "when_chat_completions_with_function_calling__should_return_ok", "when_chat_completions_with_function_calling_multiple_tools_should_return_ok", "when_chat_completions_with_response_format__should_return_ok", + "when_assistant_with_user_defined_plans_should_return_ok", ], ) def test_studio(test_file_name: str): diff --git a/tests/unittests/clients/studio/resources/assistant/__init__.py b/tests/unittests/clients/studio/resources/assistant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/assistant/plans/__init__.py b/tests/unittests/clients/studio/resources/assistant/plans/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/assistant/plans/test_plan_body_creation.py b/tests/unittests/clients/studio/resources/assistant/plans/test_plan_body_creation.py new file mode 100644 index 00000000..bfd202c2 --- /dev/null +++ b/tests/unittests/clients/studio/resources/assistant/plans/test_plan_body_creation.py @@ -0,0 +1,130 @@ +from typing import Callable, List, Dict, Any, Type, Union + +from pydantic import BaseModel +from ai21.clients.common.beta.assistant.plans import BasePlans +from ai21.errors import CodeParsingError +from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse +from ai21.types import NotGiven, NOT_GIVEN +import pytest + + +class PlanTestClass(BasePlans): + def create( + self, + *, + assistant_id: str, + code: Union[str, Callable], + schemas: Union[List[Dict[str, Any]], List[Type[BaseModel]], NotGiven] = NOT_GIVEN, + **kwargs, + ) -> PlanResponse: + pass + + def list(self, *, assistant_id: str) -> ListPlanResponse: + pass + + def retrieve(self, *, assistant_id: str, plan_id: str) -> PlanResponse: + pass + + def modify( + self, *, assistant_id: str, plan_id: str, code: str, schemas: Union[List[Dict[str, Any]], NotGiven] = NOT_GIVEN + ) -> PlanResponse: + pass + + +def test_create_body__when_pass_code_str__should_return_dict(): + # Arrange + code = "code" + + # Act + result = PlanTestClass()._create_body(code=code) + + # Assert + assert result == {"code": code} + + +def test_create_body__when_pass_code_callable__should_return_dict(): + # Arrange + def code(): + return "code" + + # Act + result = PlanTestClass()._create_body(code=code) + + # Assert + assert result == {"code": 'def code():\n return "code"'} + + +def test_create_body__when_pass_code_and_dict_schemas__should_return_dict_with_schemas(): + # Arrange + code = "code" + schemas = [{"type": "object", "properties": {"name": {"type": "string"}}}] + + # Act + result = PlanTestClass()._create_body(code=code, schemas=schemas) + + # Assert + assert result == {"code": code, "schemas": schemas} + + +class TestSchema(BaseModel): + name: str + age: int + + +def test_create_body__when_pass_code_and_pydantic_schemas__should_return_dict_with_converted_schemas(): + # Arrange + code = "code" + schemas = [TestSchema] + + # Act + result = PlanTestClass()._create_body(code=code, schemas=schemas) + + # Assert + expected_schema = { + "properties": {"age": {"title": "Age", "type": "integer"}, "name": {"title": "Name", "type": "string"}}, + "required": ["name", "age"], + "title": "TestSchema", + "type": "object", + } + assert result == {"code": code, "schemas": [expected_schema]} + + +def test_create_body__when_pass_code_and_not_given_schemas__should_return_dict_without_schemas(): + # Arrange + code = "code" + + # Act + result = PlanTestClass()._create_body(code=code, schemas=NOT_GIVEN) + + # Assert + assert result == {"code": code} + + +def test_create_body__when_pass_empty_schemas_list__should_return_dict_with_empty_schemas(): + # Arrange + code = "code" + schemas = [] + + # Act + result = PlanTestClass()._create_body(code=code, schemas=schemas) + + # Assert + assert result == {"code": code, "schemas": schemas} + + +def test_create_body__when_cannot_get_source_code__should_raise_code_parsing_error(): + # Arrange + class CallableWithoutSource: + def __call__(self): + return "result" + + # Override __code__ to simulate a built-in function or method + @property + def __code__(self): + raise AttributeError("'CallableWithoutSource' object has no attribute '__code__'") + + code = CallableWithoutSource() + + # Act & Assert + with pytest.raises(CodeParsingError): + PlanTestClass()._create_body(code=code)