Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from aixplain.modules.agent import Agent, Tool
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.model import Model
from aixplain.modules.pipeline import Pipeline
from aixplain.utils import config
from typing import Dict, List, Optional, Text, Union
Expand Down Expand Up @@ -66,11 +67,12 @@ def create(
if isinstance(tool, ModelTool):
tool_payload.append(
{
"function": tool.function.value,
"function": tool.function.value if tool.function is not None else None,
"type": "model",
"description": tool.description,
"supplier": tool.supplier.value["code"] if tool.supplier else None,
"version": tool.version if tool.version else None,
"assetId": tool.model,
}
)
elif isinstance(tool, PipelineTool):
Expand Down Expand Up @@ -116,9 +118,14 @@ def create(
return agent

@classmethod
def create_model_tool(cls, function: Union[Function, Text], supplier: Optional[Union[Supplier, Text]] = None) -> ModelTool:
def create_model_tool(
cls,
model: Optional[Union[Model, Text]] = None,
function: Optional[Union[Function, Text]] = None,
supplier: Optional[Union[Supplier, Text]] = None,
) -> ModelTool:
"""Create a new model tool."""
if isinstance(function, str):
if function is not None and isinstance(function, str):
function = Function(function)

if supplier is not None:
Expand All @@ -129,7 +136,7 @@ def create_model_tool(cls, function: Union[Function, Text], supplier: Optional[U
break
if isinstance(supplier, str):
supplier = None
return ModelTool(function=function, supplier=supplier)
return ModelTool(function=function, supplier=supplier, model=model)

@classmethod
def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool:
Expand Down
3 changes: 2 additions & 1 deletion aixplain/factories/agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent:
break

tool = ModelTool(
function=Function(tool["function"]),
function=Function(tool["function"]) if tool["function"] is not None else None,
supplier=tool["supplier"],
version=tool["version"],
model=tool["assetId"],
)
elif tool["type"] == "pipeline":
tool = PipelineTool(description=tool["description"], pipeline=tool["assetId"])
Expand Down
28 changes: 22 additions & 6 deletions aixplain/modules/agent/tool/model_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
Description:
Agentification Class
"""
from typing import Optional
from typing import Optional, Union, Text

from aixplain.enums.function import Function
from aixplain.enums.supplier import Supplier
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.agent.tool import Tool
from aixplain.modules.model import Model


class ModelTool(Tool):
Expand All @@ -37,24 +39,38 @@ class ModelTool(Tool):

def __init__(
self,
function: Function,
function: Optional[Function] = None,
supplier: Optional[Supplier] = None,
model: Optional[Union[Text, Model]] = None,
**additional_info,
) -> None:
"""Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands.

Args:
function (Function): task that the tool performs
supplier (Optional[Union[Dict, Text, Supplier, int]], optional): Preferred supplier to perform the task. Defaults to None.
function (Optional[Function], optional): task that the tool performs. Defaults to None.
supplier (Optional[Supplier], optional): Preferred supplier to perform the task. Defaults to None.. Defaults to None.
model (Optional[Union[Text, Model]], optional): Model function. Defaults to None.
"""
assert (
function is not None or model is not None
), "Agent Creation Error: Either function or model must be provided when instantiating a tool."
super().__init__("", "", **additional_info)
if isinstance(function, str):
function = Function(function)
if function is not None:
if isinstance(function, str):
function = Function(function)
self.function = function

try:
if isinstance(supplier, dict):
supplier = Supplier(supplier)
except Exception:
supplier = None

if model is not None:
if isinstance(model, Text) is True:
model = ModelFactory.get(model)
if isinstance(model.supplier, Supplier):
supplier = model.supplier
model = model.id
self.supplier = supplier
self.model = model
2 changes: 1 addition & 1 deletion tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_end2end(run_input_map):
]:
tool["supplier"] = supplier
break
tools.append(AgentFactory.create_model_tool(function=tool["function"], supplier=tool["supplier"]))
tools.append(AgentFactory.create_model_tool(**tool))
if "pipeline_tools" in run_input_map:
for tool in run_input_map["pipeline_tools"]:
tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"]))
Expand Down
5 changes: 5 additions & 0 deletions tests/functional/agent/data/agent_test_end2end.json
Copy link
Contributor

@ikxplain ikxplain Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thiago-aixplain here you are using a fixed model id, is this model public and shared through all environments?

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
{
"function": "translation",
"supplier": "AWS"
},
{
"model": "60ddefca8d38c51c58860108",
"function": null,
"supplier": null
}
]
}
Expand Down