diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 0a543d77..36380a76 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -41,12 +41,12 @@ class AgentFactory: def create( cls, name: Text, + llm_id: Text, tools: List[Tool] = [], description: Text = "", api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, - llm_id: Optional[Text] = None, ) -> Agent: """Create a new agent in the platform.""" try: diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 014b14fc..6363a08e 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -14,12 +14,14 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent: for i, tool in enumerate(tools): if tool["type"] == "model": for supplier in Supplier: - if tool["supplier"] is not None and tool["supplier"].lower() in [supplier.value["code"].lower(), supplier.value["name"].lower()]: + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier.value["code"].lower(), + supplier.value["name"].lower(), + ]: tool["supplier"] = supplier break tool = ModelTool( - description=tool["description"], function=Function(tool["function"]), supplier=tool["supplier"], version=tool["version"], diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 38c03fe7..69bf28d5 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -20,7 +20,7 @@ Description: Agentification Class """ -from typing import Text, Optional +from typing import Optional from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier @@ -31,14 +31,12 @@ class ModelTool(Tool): """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. Attributes: - description (Text): descriptiion of the tool function (Function): task that the tool performs supplier (Optional[Union[Dict, Text, Supplier, int]], optional): Preferred supplier to perform the task. Defaults to None. """ def __init__( self, - description: Text, function: Function, supplier: Optional[Supplier] = None, **additional_info, @@ -46,12 +44,10 @@ def __init__( """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. Args: - name (Text): name of the tool - description (Text): descriptiion of the tool function (Function): task that the tool performs supplier (Optional[Union[Dict, Text, Supplier, int]], optional): Preferred supplier to perform the task. Defaults to None. """ - super().__init__("", description, **additional_info) + super().__init__("", "", **additional_info) if isinstance(function, str): function = Function(function) self.function = function diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 330d4c67..a517b198 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -31,6 +31,7 @@ class PipelineTool(Tool): Attributes: description (Text): descriptiion of the tool + pipeline (Union[Text, Pipeline]): pipeline """ def __init__( diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py new file mode 100644 index 00000000..f58dcb63 --- /dev/null +++ b/tests/functional/agent/agent_functional_test.py @@ -0,0 +1,75 @@ +__author__ = "lucaspavanelli" + +""" +Copyright 2022 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import json +from dotenv import load_dotenv + +load_dotenv() +from aixplain.factories import AgentFactory +from aixplain.modules.agent import ModelTool, PipelineTool +from aixplain.enums.supplier import Supplier + +import pytest + +RUN_FILE = "tests/functional/agent/data/agent_test_end2end.json" + + +def read_data(data_path): + return json.load(open(data_path, "r")) + + +@pytest.fixture(scope="module", params=read_data(RUN_FILE)) +def run_input_map(request): + return request.param + + +def test_end2end(run_input_map): + tools = [] + if "model_tools" in run_input_map: + for tool in run_input_map["model_tools"]: + for supplier in Supplier: + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier.value["code"].lower(), + supplier.value["name"].lower(), + ]: + tool["supplier"] = supplier + break + tools.append(ModelTool(function=tool["function"], supplier=tool["supplier"])) + if "pipeline_tools" in run_input_map: + for tool in run_input_map["pipeline_tools"]: + tools.append(PipelineTool(description=tool["description"], pipeline=tool["pipeline_id"])) + print(f"Creating agent with tools: {tools}") + agent = AgentFactory.create(name=run_input_map["agent_name"], llm_id=run_input_map["llm_id"], tools=tools) + print(f"Agent created: {agent.__dict__}") + print("Running agent") + response = agent.run(query=run_input_map["query"]) + print(f"Agent response: {response}") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "data" in response + assert response["data"]["session_id"] is not None + assert response["data"]["output"] is not None + print("Deleting agent") + agent.delete() + + +def test_list_agents(): + agents = AgentFactory.list() + assert "results" in agents + agents_result = agents["results"] + assert type(agents_result) is list diff --git a/tests/functional/agent/data/agent_test_end2end.json b/tests/functional/agent/data/agent_test_end2end.json new file mode 100644 index 00000000..147928fe --- /dev/null +++ b/tests/functional/agent/data/agent_test_end2end.json @@ -0,0 +1,14 @@ +[ + { + "agent_name": "[TEST] Translation agent", + "llm_id": "6626a3a8c8f1d089790cf5a2", + "llm_name": "Groq Llama 3 70B", + "query": "Who is the president of Brazil right now? Translate to pt", + "model_tools": [ + { + "function": "translation", + "supplier": "AWS" + } + ] + } +]