diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 36380a76..7507eef4 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -24,10 +24,12 @@ import json import logging +from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier from aixplain.modules.agent import Agent, Tool from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool +from aixplain.modules.pipeline import Pipeline from aixplain.utils import config from typing import Dict, List, Optional, Text, Union @@ -113,6 +115,27 @@ def create( raise Exception(e) return agent + @classmethod + def create_model_tool(cls, function: Union[Function, Text], supplier: Optional[Union[Supplier, Text]] = None) -> ModelTool: + """Create a new model tool.""" + if isinstance(function, str): + function = Function(function) + + if supplier is not None: + if isinstance(supplier, str): + for supplier_ in Supplier: + if supplier.lower() in [supplier.value["code"].lower(), supplier.value["name"].lower()]: + supplier = supplier_ + break + if isinstance(supplier, str): + supplier = None + return ModelTool(function=function, supplier=supplier) + + @classmethod + def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool: + """Create a new pipeline tool.""" + return PipelineTool(description=description, pipeline=pipeline) + @classmethod def list(cls) -> Dict: """List all agents available in the platform.""" diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index f58dcb63..766ba386 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -20,7 +20,6 @@ load_dotenv() from aixplain.factories import AgentFactory -from aixplain.modules.agent import ModelTool, PipelineTool from aixplain.enums.supplier import Supplier import pytest @@ -48,10 +47,10 @@ def test_end2end(run_input_map): ]: tool["supplier"] = supplier break - tools.append(ModelTool(function=tool["function"], supplier=tool["supplier"])) + tools.append(AgentFactory.create_model_tool(function=tool["function"], supplier=tool["supplier"])) if "pipeline_tools" in run_input_map: for tool in run_input_map["pipeline_tools"]: - tools.append(PipelineTool(description=tool["description"], pipeline=tool["pipeline_id"])) + tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) print(f"Creating agent with tools: {tools}") agent = AgentFactory.create(name=run_input_map["agent_name"], llm_id=run_input_map["llm_id"], tools=tools) print(f"Agent created: {agent.__dict__}")