diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 7507eef4..6076eef6 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -29,6 +29,7 @@ from aixplain.modules.agent import Agent, Tool from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool +from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config from typing import Dict, List, Optional, Text, Union @@ -66,11 +67,12 @@ def create( if isinstance(tool, ModelTool): tool_payload.append( { - "function": tool.function.value, + "function": tool.function.value if tool.function is not None else None, "type": "model", "description": tool.description, "supplier": tool.supplier.value["code"] if tool.supplier else None, "version": tool.version if tool.version else None, + "assetId": tool.model, } ) elif isinstance(tool, PipelineTool): @@ -116,9 +118,14 @@ def create( return agent @classmethod - def create_model_tool(cls, function: Union[Function, Text], supplier: Optional[Union[Supplier, Text]] = None) -> ModelTool: + def create_model_tool( + cls, + model: Optional[Union[Model, Text]] = None, + function: Optional[Union[Function, Text]] = None, + supplier: Optional[Union[Supplier, Text]] = None, + ) -> ModelTool: """Create a new model tool.""" - if isinstance(function, str): + if function is not None and isinstance(function, str): function = Function(function) if supplier is not None: @@ -129,7 +136,7 @@ def create_model_tool(cls, function: Union[Function, Text], supplier: Optional[U break if isinstance(supplier, str): supplier = None - return ModelTool(function=function, supplier=supplier) + return ModelTool(function=function, supplier=supplier, model=model) @classmethod def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool: diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 6363a08e..4b314ef7 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -22,9 +22,10 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent: break tool = ModelTool( - function=Function(tool["function"]), + function=Function(tool["function"]) if tool["function"] is not None else None, supplier=tool["supplier"], version=tool["version"], + model=tool["assetId"], ) elif tool["type"] == "pipeline": tool = PipelineTool(description=tool["description"], pipeline=tool["assetId"]) diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 69bf28d5..a5acab30 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -20,11 +20,13 @@ Description: Agentification Class """ -from typing import Optional +from typing import Optional, Union, Text from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier +from aixplain.factories.model_factory import ModelFactory from aixplain.modules.agent.tool import Tool +from aixplain.modules.model import Model class ModelTool(Tool): @@ -37,19 +39,25 @@ class ModelTool(Tool): def __init__( self, - function: Function, + function: Optional[Function] = None, supplier: Optional[Supplier] = None, + model: Optional[Union[Text, Model]] = None, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. Args: - function (Function): task that the tool performs - supplier (Optional[Union[Dict, Text, Supplier, int]], optional): Preferred supplier to perform the task. Defaults to None. + function (Optional[Function], optional): task that the tool performs. Defaults to None. + supplier (Optional[Supplier], optional): Preferred supplier to perform the task. Defaults to None.. Defaults to None. + model (Optional[Union[Text, Model]], optional): Model function. Defaults to None. """ + assert ( + function is not None or model is not None + ), "Agent Creation Error: Either function or model must be provided when instantiating a tool." super().__init__("", "", **additional_info) - if isinstance(function, str): - function = Function(function) + if function is not None: + if isinstance(function, str): + function = Function(function) self.function = function try: @@ -57,4 +65,12 @@ def __init__( supplier = Supplier(supplier) except Exception: supplier = None + + if model is not None: + if isinstance(model, Text) is True: + model = ModelFactory.get(model) + if isinstance(model.supplier, Supplier): + supplier = model.supplier + model = model.id self.supplier = supplier + self.model = model diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 766ba386..427f62e5 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -47,7 +47,7 @@ def test_end2end(run_input_map): ]: tool["supplier"] = supplier break - tools.append(AgentFactory.create_model_tool(function=tool["function"], supplier=tool["supplier"])) + tools.append(AgentFactory.create_model_tool(**tool)) if "pipeline_tools" in run_input_map: for tool in run_input_map["pipeline_tools"]: tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) diff --git a/tests/functional/agent/data/agent_test_end2end.json b/tests/functional/agent/data/agent_test_end2end.json index 147928fe..94bfc94b 100644 --- a/tests/functional/agent/data/agent_test_end2end.json +++ b/tests/functional/agent/data/agent_test_end2end.json @@ -8,6 +8,11 @@ { "function": "translation", "supplier": "AWS" + }, + { + "model": "60ddefca8d38c51c58860108", + "function": null, + "supplier": null } ] }