diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6b06079..456aba3b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,5 +21,5 @@ repos: rev: v2.0.0 # Use the latest version hooks: - id: flake8 - args: # arguments to configure black - - --ignore=E402,E501 \ No newline at end of file + args: # arguments to configure flake8 + - --ignore=E402,E501,E203 \ No newline at end of file diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index ef497ddd..555f4920 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -1,3 +1,4 @@ +# flake8: noqa: F401 // to ignore the F401 (unused import) from .data_split import DataSplit from .data_subtype import DataSubtype from .data_type import DataType @@ -14,3 +15,4 @@ from .sort_by import SortBy from .sort_order import SortOrder from .response_status import ResponseStatus +from .database_source import DatabaseSourceType diff --git a/aixplain/enums/database_source.py b/aixplain/enums/database_source.py new file mode 100644 index 00000000..7c5eaa67 --- /dev/null +++ b/aixplain/enums/database_source.py @@ -0,0 +1,47 @@ +__author__ = "aiXplain" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Lucas Pavanelli and Thiago Castro Ferreira and Ahmet Gunduz +Date: March 7th 2025 +Description: + Database Source Type Enum +""" + +from enum import Enum + + +class DatabaseSourceType(Enum): + """Enum for database source types""" + + POSTGRESQL = "postgresql" + SQLITE = "sqlite" + CSV = "csv" + + @classmethod + def from_string(cls, source_type: str) -> "DatabaseSourceType": + """Convert string to DatabaseSourceType enum + + Args: + source_type (str): Source type string + + Returns: + DatabaseSourceType: Corresponding enum value + """ + try: + return cls[source_type.upper()] + except KeyError: + raise ValueError(f"Invalid source type: {source_type}") diff --git a/aixplain/enums/status.py b/aixplain/enums/status.py new file mode 100644 index 00000000..3c84f2be --- /dev/null +++ b/aixplain/enums/status.py @@ -0,0 +1,8 @@ +from enum import Enum +from typing import Text + + +class Status(Text, Enum): + FAILED = "failed" + IN_PROGRESS = "in_progress" + SUCCESS = "success" diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index e1f18a36..af59f17f 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -24,6 +24,7 @@ import json import logging import warnings +import os from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier @@ -32,7 +33,9 @@ from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool -from aixplain.modules.agent.tool.sql_tool import SQLTool +from aixplain.modules.agent.tool.sql_tool import ( + SQLTool, +) from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config @@ -121,8 +124,8 @@ def create( "status": "draft", "tasks": [task.to_dict() for task in tasks], } - agent = build_agent(payload=payload, api_key=api_key) - agent.validate() + agent = build_agent(payload=payload, tools=tools, api_key=api_key) + agent.validate(raise_exception=True) response = "Unspecified error" try: logging.debug(f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(agent.to_dict())}") @@ -132,7 +135,7 @@ def create( raise Exception("Agent Onboarding Error: Please contact the administrators.") if 200 <= r.status_code < 300: - agent = build_agent(payload=response, api_key=api_key) + agent = build_agent(payload=response, tools=tools, api_key=api_key) else: error_msg = f"Agent Onboarding Error: {response}" if "message" in response: @@ -193,7 +196,8 @@ def create_custom_python_code_tool(cls, code: Union[Text, Callable], description def create_sql_tool( cls, description: Text, - database: Text, + source: str, + source_type: str, schema: Optional[Text] = None, tables: Optional[List[Text]] = None, enable_commit: bool = False, @@ -202,15 +206,114 @@ def create_sql_tool( Args: description (Text): description of the database tool - database (Text): URL/local path of the SQLite database file - schema (Optional[Text], optional): database schema description (optional) + source (Union[Text, Dict]): database source - can be a connection string or dictionary with connection details + source_type (Text): type of source (postgresql, sqlite, csv) + schema (Optional[Text], optional): database schema description tables (Optional[List[Text]], optional): table names to work with (optional) enable_commit (bool, optional): enable to modify the database (optional) - Returns: SQLTool: created SQLTool + + Examples: + # CSV - Simple + sql_tool = AgentFactory.create_sql_tool( + description="My CSV Tool", + source="/path/to/data.csv", + source_type="csv", + tables=["data"] + ) + + # SQLite - Simple + sql_tool = AgentFactory.create_sql_tool( + description="My SQLite Tool", + source="/path/to/database.sqlite", + source_type="sqlite", + tables=["users", "products"] + ) """ - return SQLTool(description=description, database=database, schema=schema, tables=tables, enable_commit=enable_commit) + from aixplain.modules.agent.tool.sql_tool import ( + SQLToolError, + create_database_from_csv, + get_table_schema, + get_table_names_from_schema, + ) + from aixplain.enums import DatabaseSourceType + + if not source: + raise SQLToolError("Source must be provided") + if not source_type: + raise SQLToolError("Source type must be provided") + + # Validate source type + try: + source_type = DatabaseSourceType.from_string(source_type) + except ValueError as e: + raise SQLToolError(str(e)) + + database_path = None # Final database path to pass to SQLTool + + # Handle CSV source type + if source_type == DatabaseSourceType.CSV: + if not os.path.exists(source): + raise SQLToolError(f"CSV file '{source}' does not exist") + if not source.endswith(".csv"): + raise SQLToolError(f"File '{source}' is not a CSV file") + + # Create database name from CSV filename or use custom table name + base_name = os.path.splitext(os.path.basename(source))[0] + db_path = os.path.join(os.path.dirname(source), f"{base_name}.db") + + try: + # Create database from CSV + schema = create_database_from_csv(source, db_path) + database_path = db_path + + # Get table names if not provided + if not tables: + tables = get_table_names_from_schema(schema) + + except Exception as e: + if os.path.exists(db_path): + try: + os.remove(db_path) + except Exception as cleanup_error: + warnings.warn(f"Failed to remove temporary database file '{db_path}': {str(cleanup_error)}") + raise SQLToolError(f"Failed to create database from CSV: {str(e)}") + + # Handle SQLite source type + elif source_type == DatabaseSourceType.SQLITE: + if not os.path.exists(source): + raise SQLToolError(f"Database '{source}' does not exist") + if not source.endswith(".db") and not source.endswith(".sqlite"): + raise SQLToolError(f"Database '{source}' must have .db or .sqlite extension") + + database_path = source + + # Infer schema from database if not provided + if not schema: + try: + schema = get_table_schema(database_path) + except Exception as e: + raise SQLToolError(f"Failed to get database schema: {str(e)}") + + # Get table names if not provided + if not tables: + try: + tables = get_table_names_from_schema(schema) + except Exception as e: + raise SQLToolError(f"Failed to get table names: {str(e)}") + + elif source_type == DatabaseSourceType.POSTGRESQL: + raise SQLToolError("PostgreSQL is not supported yet") + + # Create and return SQLTool + return SQLTool( + description=description, + database=database_path, + schema=schema, + tables=tables, + enable_commit=enable_commit, + ) @classmethod def list(cls) -> Dict: diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 6b96b8a9..d64ab773 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -1,69 +1,93 @@ __author__ = "thiagocastroferreira" +import logging import aixplain.utils.config as config from aixplain.enums import Function, Supplier from aixplain.enums.asset_status import AssetStatus from aixplain.modules.agent import Agent +from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool from aixplain.modules.agent.tool.sql_tool import SQLTool -from typing import Dict, Text +from typing import Dict, Text, List from urllib.parse import urljoin GPT_4o_ID = "6646261c6eb563165658bbb1" -def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent: +def build_tool(tool: Dict): + """Build a tool from a dictionary. + + Args: + tool (Dict): Tool dictionary. + + Returns: + Tool: Tool object. + """ + if tool["type"] == "model": + supplier = "aixplain" + for supplier_ in Supplier: + if isinstance(tool["supplier"], str): + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier_.value["code"].lower(), + supplier_.value["name"].lower(), + ]: + supplier = supplier_ + break + tool = ModelTool( + function=Function(tool.get("function", None)), + supplier=supplier, + version=tool["version"], + model=tool["assetId"], + description=tool.get("description", ""), + parameters=tool.get("parameters", None), + ) + elif tool["type"] == "pipeline": + tool = PipelineTool(description=tool["description"], pipeline=tool["assetId"]) + elif tool["type"] == "utility": + if tool.get("utilityCode", None) is not None: + tool = CustomPythonCodeTool(description=tool["description"], code=tool["utilityCode"]) + else: + tool = PythonInterpreterTool() + elif tool["type"] == "sql": + parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])} + database = parameters.get("database") + schema = parameters.get("schema") + tables = parameters.get("tables", None) + tables = tables.split(",") if tables is not None else None + enable_commit = parameters.get("enable_commit", False) + tool = SQLTool( + description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit + ) + else: + raise Exception("Agent Creation Error: Tool type not supported.") + + return tool + + +def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config.TEAM_API_KEY) -> Agent: """Instantiate a new agent in the platform.""" tools_dict = payload["assets"] - tools = [] - for tool in tools_dict: - if tool["type"] == "model": - supplier = "aixplain" - for supplier_ in Supplier: - if isinstance(tool["supplier"], str): - if tool["supplier"] is not None and tool["supplier"].lower() in [ - supplier_.value["code"].lower(), - supplier_.value["name"].lower(), - ]: - supplier = supplier_ - break - tool = ModelTool( - function=Function(tool.get("function", None)), - supplier=supplier, - version=tool["version"], - model=tool["assetId"], - description=tool.get("description", ""), - parameters=tool.get("parameters", None), - ) - elif tool["type"] == "pipeline": - tool = PipelineTool(description=tool["description"], pipeline=tool["assetId"]) - elif tool["type"] == "utility": - if tool.get("utilityCode", None) is not None: - tool = CustomPythonCodeTool(description=tool["description"], code=tool["utilityCode"]) - else: - tool = PythonInterpreterTool() - elif tool["type"] == "sql": - parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])} - database = parameters.get("database") - schema = parameters.get("schema") - tables = parameters.get("tables", None) - tables = tables.split(",") if tables is not None else None - enable_commit = parameters.get("enable_commit", False) - tool = SQLTool( - description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit - ) - else: - raise Exception("Agent Creation Error: Tool type not supported.") - tools.append(tool) + payload_tools = tools + if payload_tools is None: + payload_tools = [] + for tool in tools_dict: + try: + payload_tools.append(build_tool(tool)) + except Exception: + logging.warning( + f"Tool {tool['assetId']} is not available. Make sure it exists or you have access to it. " + "If you think this is an error, please contact the administrators." + ) + continue agent = Agent( id=payload["id"] if "id" in payload else "", name=payload.get("name", ""), - tools=tools, + tools=payload_tools, description=payload.get("description", ""), instructions=payload.get("role", ""), supplier=payload.get("teamId", None), diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index ea145d9a..e17841e6 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -45,27 +45,38 @@ def create( api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, + use_mentalist: bool = True, use_inspector: bool = True, - use_mentalist_and_inspector: bool = True, + use_mentalist_and_inspector: bool = False, # TODO: remove this ) -> TeamAgent: """Create a new team agent in the platform.""" assert len(agents) > 0, "TeamAgent Onboarding Error: At least one agent must be provided." + agent_list = [] for agent in agents: if isinstance(agent, Text) is True: try: from aixplain.factories.agent_factory import AgentFactory - agent = AgentFactory.get(agent) + agent_obj = AgentFactory.get(agent) except Exception: raise Exception(f"TeamAgent Onboarding Error: Agent {agent} does not exist.") else: from aixplain.modules.agent import Agent + agent_obj = agent + assert isinstance(agent, Agent), "TeamAgent Onboarding Error: Agents must be instances of Agent class" - - mentalist_and_inspector_llm_id = None - if use_inspector or use_mentalist_and_inspector: - mentalist_and_inspector_llm_id = llm_id + agent_list.append(agent_obj) + + if use_inspector and not use_mentalist: + raise Exception("TeamAgent Onboarding Error: To use the Inspector agent, you must enable Mentalist.") + + if use_mentalist_and_inspector: + mentalist_llm_id = llm_id + inspector_llm_id = llm_id + else: + mentalist_llm_id = llm_id if use_mentalist else None + inspector_llm_id = llm_id if use_inspector else None team_agent = None url = urljoin(config.BACKEND_URL, "sdk/agent-communities") @@ -76,26 +87,26 @@ def create( elif isinstance(supplier, Supplier): supplier = supplier.value["code"] - agent_list = [] + agent_payload_list = [] for idx, agent in enumerate(agents): - agent_list.append({"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"}) + agent_payload_list.append({"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"}) payload = { "name": name, - "agents": agent_list, + "agents": agent_payload_list, "links": [], "description": description, "llmId": llm_id, "supervisorId": llm_id, - "plannerId": mentalist_and_inspector_llm_id, - "inspectorId": mentalist_and_inspector_llm_id, + "plannerId": mentalist_llm_id, + "inspectorId": inspector_llm_id, "supplier": supplier, "version": version, "status": "draft", } - team_agent = build_team_agent(payload=payload, api_key=api_key) - team_agent.validate() + team_agent = build_team_agent(payload=payload, agents=agent_list, api_key=api_key) + team_agent.validate(raise_exception=True) response = "Unspecified error" try: logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}") @@ -105,7 +116,7 @@ def create( raise Exception(e) if 200 <= r.status_code < 300: - team_agent = build_team_agent(payload=response, api_key=api_key) + team_agent = build_team_agent(payload=response, agents=agent_list, api_key=api_key) else: error_msg = f"{response}" if "message" in response: diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 72afa1d1..5e865cd0 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -1,34 +1,45 @@ __author__ = "lucaspavanelli" +import logging import aixplain.utils.config as config from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.agent import Agent from aixplain.modules.team_agent import TeamAgent -from typing import Dict, Text +from typing import Dict, Text, List from urllib.parse import urljoin GPT_4o_ID = "6646261c6eb563165658bbb1" -def build_team_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> TeamAgent: +def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = config.TEAM_API_KEY) -> TeamAgent: """Instantiate a new team agent in the platform.""" from aixplain.factories.agent_factory import AgentFactory agents_dict = payload["agents"] - agents = [] - for i, agent in enumerate(agents_dict): - agent = AgentFactory.get(agent["assetId"]) - agents.append(agent) + payload_agents = agents + if payload_agents is None: + payload_agents = [] + for i, agent in enumerate(agents_dict): + try: + payload_agents.append(AgentFactory.get(agent["assetId"])) + except Exception: + logging.warning( + f"Agent {agent['assetId']} not found. Make sure it exists or you have access to it. " + "If you think this is an error, please contact the administrators." + ) + continue team_agent = TeamAgent( id=payload.get("id", ""), name=payload.get("name", ""), - agents=agents, + agents=payload_agents, description=payload.get("description", ""), supplier=payload.get("teamId", None), version=payload.get("version", None), cost=payload.get("cost", None), llm_id=payload.get("llmId", GPT_4o_ID), - use_mentalist_and_inspector=True if payload["plannerId"] is not None else False, + use_mentalist=True if payload.get("plannerId", None) is not None else False, + use_inspector=True if payload.get("inspectorId", None) is not None else False, api_key=api_key, status=AssetStatus(payload["status"]), ) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 83844c3c..c0209313 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -62,6 +62,8 @@ class Agent(Model): cost (Dict, optional): model price. Defaults to None. """ + is_valid: bool + def __init__( self, id: Text, @@ -107,8 +109,9 @@ def __init__( status = AssetStatus.DRAFT self.status = status self.tasks = tasks + self.is_valid = True - def validate(self) -> None: + def _validate(self) -> None: """Validate the Agent.""" from aixplain.factories.model_factory import ModelFactory @@ -119,15 +122,36 @@ def validate(self) -> None: try: llm = ModelFactory.get(self.llm_id, api_key=self.api_key) - assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") + assert ( + llm.function == Function.TEXT_GENERATION + ), "Large Language Model must be a text generation model." + for tool in self.tools: if isinstance(tool, Tool): tool.validate() elif isinstance(tool, Model): - assert not isinstance(tool, Agent), "Agent cannot contain another Agent." + assert not isinstance( + tool, Agent + ), "Agent cannot contain another Agent." + + def validate(self, raise_exception: bool = False) -> bool: + """Validate the Agent.""" + try: + self._validate() + self.is_valid = True + except Exception as e: + self.is_valid = False + if raise_exception: + raise e + else: + logging.warning(f"Agent Validation Error: {e}") + logging.warning( + "You won't be able to run the Agent until the issues are handled manually." + ) + return self.is_valid def run( self, @@ -183,7 +207,9 @@ def run( return response poll_url = response["url"] end = time.time() - result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result = self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time + ) result_data = result.data return AgentResponse( status=ResponseStatus.SUCCESS, @@ -245,10 +271,19 @@ def run_async( """ from aixplain.factories.file_factory import FileFactory - assert data is not None or query is not None, "Either 'data' or 'query' must be provided." + if not self.is_valid: + raise Exception( + "Agent is not valid. Please validate the agent before running." + ) + + assert ( + data is not None or query is not None + ), "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + assert ( + "query" in data and data["query"] is not None + ), "When providing a dictionary, 'query' must be provided." query = data.get("query") if session_id is None: session_id = data.get("session_id") @@ -261,7 +296,9 @@ def run_async( # process content inputs if content is not None: - assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." + assert ( + FileFactory.check_storage_type(query) == StorageType.TEXT + ), "When providing 'content', query must be text." if isinstance(content, list): assert len(content) <= 3, "The maximum number of content inputs is 3." @@ -270,7 +307,9 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." + assert ( + "{{" + key + "}}" in query + ), f"Key '{key}' not found in query." value = FileFactory.to_link(value) query = query.replace("{{" + key + "}}", f"'{value}'") @@ -285,8 +324,16 @@ def run_async( "sessionId": session_id, "history": history, "executionParams": { - "maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, - "maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, + "maxTokens": ( + parameters["max_tokens"] + if "max_tokens" in parameters + else max_tokens + ), + "maxIterations": ( + parameters["max_iterations"] + if "max_iterations" in parameters + else max_iterations + ), "outputFormat": output_format.value, }, } @@ -320,7 +367,11 @@ def to_dict(self) -> Dict: "assets": [tool.to_dict() for tool in self.tools], "description": self.description, "role": self.instructions, - "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, + "supplier": ( + self.supplier.value["code"] + if isinstance(self.supplier, Supplier) + else self.supplier + ), "version": self.version, "llmId": self.llm_id, "status": self.status.value, @@ -331,7 +382,10 @@ def delete(self) -> None: """Delete Agent service""" try: url = urljoin(config.BACKEND_URL, f"sdk/agents/{self.id}") - headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + headers = { + "x-api-key": config.TEAM_API_KEY, + "Content-Type": "application/json", + } logging.debug(f"Start service for DELETE Agent - {url} - {headers}") r = _request_with_retry("delete", url, headers=headers) logging.debug(f"Result of request for DELETE Agent - {r.status_code}") @@ -355,19 +409,22 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " + "Please use save() instead.", DeprecationWarning, stacklevel=2, ) from aixplain.factories.agent_factory.utils import build_agent - self.validate() + self.validate(raise_exception=True) url = urljoin(config.BACKEND_URL, f"sdk/agents/{self.id}") headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} payload = self.to_dict() - logging.debug(f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}") + logging.debug( + f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}" + ) resp = "No specified error." try: r = _request_with_retry("put", url, headers=headers, json=payload) @@ -386,7 +443,9 @@ def save(self) -> None: self.update() def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." + assert ( + self.status == AssetStatus.DRAFT + ), "Agent must be in draft status to be deployed." assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." self.status = AssetStatus.ONBOARDED self.update() diff --git a/aixplain/modules/agent/agent_response.py b/aixplain/modules/agent/agent_response.py index 9ece7aa7..73c5e839 100644 --- a/aixplain/modules/agent/agent_response.py +++ b/aixplain/modules/agent/agent_response.py @@ -52,5 +52,5 @@ def to_dict(self) -> Dict[Text, Any]: return base_dict def __repr__(self) -> str: - fields = super().__repr__().strip("ModelResponse(").rstrip(")") + fields = super().__repr__()[len("ModelResponse(") : -1] return f"AgentResponse({fields})" diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 5175b1b4..f0cb88e7 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -114,6 +114,9 @@ def to_dict(self) -> Dict: def validate(self) -> Model: from aixplain.factories.model_factory import ModelFactory + if self.model_object is not None: + return self.model_object + try: model = None if self.model is not None: diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index eb2ebdde..6935e6df 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -21,12 +21,229 @@ Agentification Class """ import os +import warnings import validators +import pandas as pd +import numpy as np from typing import Text, Optional, Dict, List, Union +import sqlite3 from aixplain.modules.agent.tool import Tool +class SQLToolError(Exception): + """Base exception for SQL Tool errors""" + + pass + + +class CSVError(SQLToolError): + """Exception for CSV-related errors""" + + pass + + +class DatabaseError(SQLToolError): + """Exception for database-related errors""" + + pass + + +def clean_column_name(col: Text) -> Text: + """Clean column names by replacing spaces and special characters with underscores""" + # Replace special characters with underscores + cleaned = col.strip().lower() + cleaned = "".join(c if c.isalnum() else "_" for c in cleaned) + # Remove multiple consecutive underscores + while "__" in cleaned: + cleaned = cleaned.replace("__", "_") + # Remove leading/trailing underscores + cleaned = cleaned.strip("_") + + # Add 'col_' prefix to columns that start with numbers + if cleaned[0].isdigit(): + cleaned = "col_" + cleaned + + return cleaned + + +def check_duplicate_columns(df: pd.DataFrame) -> None: + """Check for duplicate column names in DataFrame and raise CSVError if found""" + # Get all column names + columns = df.columns.tolist() + # Get cleaned column names + cleaned_columns = [clean_column_name(col) for col in columns] + + # Check for duplicates in cleaned names + seen = set() + duplicates = [] + + for original, cleaned in zip(columns, cleaned_columns): + if cleaned in seen: + duplicates.append(original) + seen.add(cleaned) + + if duplicates: + raise CSVError(f"CSV file contains duplicate column names after cleaning: {', '.join(duplicates)}") + + +def infer_sqlite_type(dtype) -> Text: + """Infer SQLite type from pandas dtype""" + if pd.api.types.is_integer_dtype(dtype): + return "INTEGER" + elif pd.api.types.is_float_dtype(dtype): + return "REAL" + elif pd.api.types.is_bool_dtype(dtype): + return "INTEGER" + elif pd.api.types.is_datetime64_any_dtype(dtype): + return "TIMESTAMP" + else: + warnings.warn(f"Column with dtype '{dtype}' will be stored as TEXT in SQLite") + return "TEXT" + + +def get_table_schema(database_path: str) -> str: + """Get the schema of all tables in the database""" + if not os.path.exists(database_path): + raise DatabaseError(f"Database file '{database_path}' does not exist") + + try: + with sqlite3.connect(database_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT sql + FROM sqlite_master + WHERE type='table' AND sql IS NOT NULL + """ + ) + schemas = cursor.fetchall() + if not schemas: + warnings.warn(f"No tables found in database '{database_path}'") + return "\n".join(schema[0] for schema in schemas if schema[0]) + except sqlite3.Error as e: + raise DatabaseError(f"Failed to get table schema: {str(e)}") + except Exception as e: + raise DatabaseError(f"Unexpected error while getting table schema: {str(e)}") + + +def create_database_from_csv(csv_path: str, database_path: str) -> str: + """Create SQLite database from CSV file and return the schema""" + if not os.path.exists(csv_path): + raise CSVError(f"CSV file '{csv_path}' does not exist") + if not csv_path.endswith(".csv"): + raise CSVError(f"File '{csv_path}' is not a CSV file") + + try: + # Load CSV file + df = pd.read_csv(csv_path) + + if df.empty: + raise CSVError(f"CSV file '{csv_path}' is empty") + + # Clean column names and track changes + original_columns = df.columns.tolist() + cleaned_columns = [clean_column_name(col) for col in original_columns] + changed_columns = [(orig, cleaned) for orig, cleaned in zip(original_columns, cleaned_columns) if orig != cleaned] + + if changed_columns: + changes = ", ".join([f"'{orig}' to '{cleaned}'" for orig, cleaned in changed_columns]) + warnings.warn(f"Column names were cleaned for SQLite compatibility: {changes}") + + df.columns = cleaned_columns + + # Connect to SQLite database + if os.path.exists(database_path): + warnings.warn(f"Database '{database_path}' already exists and will be modified") + + try: + with sqlite3.connect(database_path) as conn: + cursor = conn.cursor() + + # Check if table already exists + table_name = clean_column_name(os.path.splitext(os.path.basename(csv_path))[0]) + cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'") + if cursor.fetchone(): + warnings.warn(f"Table '{table_name}' already exists in the database and will be replaced") + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + + # Create column definitions for SQL table + cols_definitions = [] + for col, dtype in df.dtypes.items(): + col_type = infer_sqlite_type(dtype) + cols_definitions.append(f'"{col}" {col_type}') + + table_columns = ", ".join(cols_definitions) + + # Create the table + create_table_query = f""" + CREATE TABLE {table_name} ( + {table_columns} + ) + """ + cursor.execute(create_table_query) + + # Insert data into table + total_rows = len(df) + for idx, row in enumerate(df.itertuples(index=False), 1): + # Convert all data to Python native types that SQLite can handle + row_data = [] + for val in row: + if pd.isna(val): + row_data.append(None) + elif isinstance(val, pd.Timestamp): + row_data.append(val.strftime("%Y-%m-%d %H:%M:%S")) + elif isinstance(val, (np.int64, np.int32)): + row_data.append(int(val)) + elif isinstance(val, (np.float64, np.float32)): + row_data.append(float(val)) + else: + row_data.append(str(val)) + + placeholders = ", ".join(["?" for _ in df.columns]) + column_names = ", ".join([f'"{col}"' for col in df.columns]) + insert_query = f"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders})" + try: + cursor.execute(insert_query, tuple(row_data)) + except sqlite3.Error as e: + raise DatabaseError(f"Error inserting row {idx}/{total_rows}: {str(e)}") + + conn.commit() + + except sqlite3.Error as e: + raise DatabaseError(f"SQLite error: {str(e)}") + + # Get the schema + try: + return get_table_schema(database_path) + except DatabaseError as e: + raise DatabaseError(f"Failed to get schema after creating database: {str(e)}") + + except pd.errors.EmptyDataError: + raise CSVError(f"CSV file '{csv_path}' is empty") + except pd.errors.ParserError as e: + raise CSVError(f"Failed to parse CSV file: {str(e)}") + except Exception as e: + if isinstance(e, CSVError): + raise e + raise CSVError(f"Unexpected error while processing CSV file: {str(e)}") + + +def get_table_names_from_schema(schema: str) -> List[str]: + """Extract table names from schema string""" + if not schema: + return [] + + table_names = [] + for line in schema.split("\n"): + line = line.strip() + if line.startswith("CREATE TABLE"): + # Extract table name from CREATE TABLE statement + table_name = line.split("CREATE TABLE")[1].strip().split("(")[0].strip().strip("\"'") + table_names.append(table_name) + return table_names + + class SQLTool(Tool): """Tool to execute SQL commands in an SQLite database. @@ -42,7 +259,7 @@ def __init__( self, description: Text, database: Text, - schema: Text, + schema: Optional[Text] = None, tables: Optional[Union[List[Text], Text]] = None, enable_commit: bool = False, **additional_info, @@ -51,8 +268,8 @@ def __init__( Args: description (Text): description of the tool - database (Text): database name - schema (Text): database schema description + database (Text): database uri + schema (Optional[Text]): database schema description tables (Optional[Union[List[Text], Text]]): table names to work with (optional) enable_commit (bool): enable to modify the database (optional) """ @@ -77,16 +294,39 @@ def to_dict(self) -> Dict[str, Text]: def validate(self): from aixplain.factories.file_factory import FileFactory - assert self.description and self.description.strip() != "", "SQL Tool Error: Description is required" - assert self.database and self.database.strip() != "", "SQL Tool Error: Database is required" + if not self.description or self.description.strip() == "": + raise SQLToolError("Description is required") + if not self.database: + raise SQLToolError("Database must be provided") + + # Handle database validation if not ( str(self.database).startswith("s3://") - or str(self.database).startswith("http://") - or str(self.database).startswith("https://") - or validators.url(self.database) + or str(self.database).startswith("http://") # noqa: W503 + or str(self.database).startswith("https://") # noqa: W503 + or validators.url(self.database) # noqa: W503 ): if not os.path.exists(self.database): - raise Exception(f"SQL Tool Error: Database '{self.database}' does not exist") + raise SQLToolError(f"Database '{self.database}' does not exist") if not self.database.endswith(".db"): - raise Exception(f"SQL Tool Error: Database '{self.database}' must have .db extension") - self.database = FileFactory.upload(local_path=self.database, is_temp=True) + raise SQLToolError(f"Database '{self.database}' must have .db extension") + + # Infer schema from database if not provided + if not self.schema: + try: + self.schema = get_table_schema(self.database) + except DatabaseError as e: + raise SQLToolError(f"Failed to get database schema: {str(e)}") + + # Set tables if not already set + if not self.tables: + try: + self.tables = get_table_names_from_schema(self.schema) + except Exception as e: + raise SQLToolError(f"Failed to set tables: {str(e)}") + + # Upload database + try: + self.database = FileFactory.upload(local_path=self.database, is_temp=True) + except Exception as e: + raise SQLToolError(f"Failed to upload database: {str(e)}") diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 3af31048..7b494305 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -23,6 +23,7 @@ def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): payload = str(payload) payload = {"data": payload} except Exception: + parameters["data"] = data payload = {"data": data} payload.update(parameters) payload = json.dumps(payload) diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index f2deb9df..cd47da87 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -26,11 +26,13 @@ import os import logging from aixplain.enums.asset_status import AssetStatus +from aixplain.enums.response_status import ResponseStatus from aixplain.modules.asset import Asset from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin +from aixplain.modules.pipeline.response import PipelineResponse class Pipeline(Asset): @@ -107,22 +109,21 @@ def __polling( # TO DO: wait_time = to the longest path of the pipeline * minimum waiting time logging.debug(f"Polling for Pipeline: Start polling for {name} ") start, end = time.time(), time.time() - completed = False - response_body = {"status": "FAILED"} - while not completed and (end - start) < timeout: + response_body = {"status": ResponseStatus.FAILED, "completed": False} + + while not response_body["completed"] and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) logging.debug(f"Polling for Pipeline: Status of polling for {name} : {response_body}") - completed = response_body["completed"] - end = time.time() - if completed is False: + if not response_body["completed"]: time.sleep(wait_time) if wait_time < 60: wait_time *= 1.1 except Exception: logging.error(f"Polling for Pipeline: polling for {name} : Continue") - if response_body and response_body["status"] == "SUCCESS": + break + if response_body["status"] == ResponseStatus.SUCCESS: try: logging.debug(f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}") except Exception: @@ -133,7 +134,9 @@ def __polling( ) return response_body - def poll(self, poll_url: Text, name: Text = "pipeline_process") -> Dict: + def poll( + self, poll_url: Text, name: Text = "pipeline_process", response_version: Text = "v2" + ) -> Union[Dict, PipelineResponse]: """Poll the platform to check whether an asynchronous call is done. Args: @@ -157,38 +160,24 @@ def poll(self, poll_url: Text, name: Text = "pipeline_process") -> Dict: except Exception: resp = r.json() logging.info(f"Single Poll for Pipeline: Status of polling for {name} : {resp}") - except Exception: - resp = {"status": "FAILED"} - return resp - - def _should_fallback_to_v2(self, response: Dict, version: str) -> bool: - """Determine if the pipeline should fallback to version 2.0 based on the response. - - Args: - response (Dict): The response from the pipeline call. - version (str): The version of the pipeline being used. - - Returns: - bool: True if fallback is needed, False otherwise. - """ - # If the version is not 3.0, no fallback is needed - if version != self.VERSION_3_0: - return False - - should_fallback = False - if "status" not in response or response["status"] == "FAILED": - should_fallback = True - elif response["status"] == "SUCCESS" and ("data" not in response or not response["data"]): - should_fallback = True - # Check for conditions that require a fallback - - if should_fallback: - logging.warning( - f"Pipeline Run Error: Failed to run pipeline {self.id} with version {version}. " - f"Trying with version {self.VERSION_2_0}." + if response_version == "v1": + return resp + status = ResponseStatus(resp.pop("status", "failed")) + response = PipelineResponse( + status=status, + error=resp.pop("error", None), + elapsed_time=resp.pop("elapsed_time", 0), + **resp, ) + return response - return should_fallback + except Exception: + return PipelineResponse( + status=ResponseStatus.FAILED, + error=resp.pop("error", None), + elapsed_time=resp.pop("elapsed_time", 0), + **resp, + ) def run( self, @@ -197,55 +186,62 @@ def run( name: Text = "pipeline_process", timeout: float = 20000.0, wait_time: float = 1.0, - batch_mode: bool = True, - version: str = None, + version: Optional[Text] = None, + response_version: Text = "v2", **kwargs, - ) -> Dict: - """Runs a pipeline call. - - Args: - data (Union[Text, Dict]): link to the input data - data_asset (Optional[Union[Text, Dict]], optional): Data asset to be processed by the pipeline. Defaults to None. - name (Text, optional): ID given to a call. Defaults to "pipeline_process". - timeout (float, optional): total polling time. Defaults to 20000.0. - wait_time (float, optional): wait time in seconds between polling calls. Defaults to 1.0. - batch_mode (bool, optional): Whether to run the pipeline in batch mode or online. Defaults to True. - kwargs: A dictionary of keyword arguments. The keys are the argument names - - Returns: - Dict: parsed output from pipeline - """ - version = version or self.VERSION_3_0 + ) -> Union[Dict, PipelineResponse]: start = time.time() try: - response = self.run_async( - data, - data_asset=data_asset, - name=name, - batch_mode=batch_mode, - **kwargs, - ) - - if response["status"] == "FAILED": + response = self.run_async(data, data_asset=data_asset, name=name, version=version, **kwargs) + if response["status"] == ResponseStatus.FAILED: end = time.time() - response["elapsed_time"] = end - start - return response - + if response_version == "v1": + return { + "status": "failed", + "error": response.get("error", "ERROR"), + "elapsed_time": end - start, + **kwargs, + } + return PipelineResponse( + status=ResponseStatus.FAILED, + error={"error": response.get("error", "ERROR"), "status": "ERROR"}, + elapsed_time=end - start, + **kwargs, + ) poll_url = response["url"] + polling_response = self.__polling(poll_url, name=name, timeout=timeout, wait_time=wait_time) end = time.time() - response = self.__polling(poll_url, name=name, timeout=timeout, wait_time=wait_time) - return response + status = ResponseStatus(polling_response["status"]) + if response_version == "v1": + polling_response["elapsed_time"] = end - start + return polling_response + status = ResponseStatus(polling_response.status) + return PipelineResponse( + status=status, + error=polling_response.error, + elapsed_time=end - start, + data=getattr(polling_response, "data", {}), + **kwargs, + ) + except Exception as e: error_message = f"Error in request for {name}: {str(e)}" logging.error(error_message) logging.exception(error_message) end = time.time() - return { - "status": "FAILED", - "error": error_message, - "elapsed_time": end - start, - "version": version, - } + if response_version == "v1": + return { + "status": "failed", + "error": error_message, + "elapsed_time": end - start, + **kwargs, + } + return PipelineResponse( + status=ResponseStatus.FAILED, + error={"error": error_message, "status": "ERROR"}, + elapsed_time=end - start, + **kwargs, + ) def __prepare_payload( self, @@ -361,7 +357,8 @@ def run_async( data_asset: Optional[Union[Text, Dict]] = None, name: Text = "pipeline_process", batch_mode: bool = True, - version: str = None, + version: Optional[Text] = None, + response_version: Text = "v2", **kwargs, ) -> Dict: """Runs asynchronously a pipeline call. @@ -371,12 +368,13 @@ def run_async( data_asset (Optional[Union[Text, Dict]], optional): Data asset to be processed by the pipeline. Defaults to None. name (Text, optional): ID given to a call. Defaults to "pipeline_process". batch_mode (bool, optional): Whether to run the pipeline in batch mode or online. Defaults to True. + version (Optional[Text], optional): Version of the pipeline. Defaults to None. + response_version (Text, optional): Version of the response. Defaults to "v2". kwargs: A dictionary of keyword arguments. The keys are the argument names Returns: Dict: polling URL in response """ - version = version or self.VERSION_3_0 headers = { "x-api-key": self.api_key, "Content-Type": "application/json", @@ -390,14 +388,21 @@ def run_async( call_url = f"{self.url}/{self.id}" logging.info(f"Start service for {name} - {call_url} - {payload}") r = _request_with_retry("post", call_url, headers=headers, data=payload) - resp = None try: if 200 <= r.status_code < 300: resp = r.json() logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - poll_url = resp["url"] - response = {"status": "IN_PROGRESS", "url": poll_url} + if response_version == "v1": + return resp + res = PipelineResponse( + status=ResponseStatus(resp.pop("status", "failed")), + url=resp["url"], + elapsed_time=None, + **kwargs, + ) + return res + else: if r.status_code == 401: error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." @@ -414,14 +419,35 @@ def run_async( error = ( f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." ) - response = {"status": "FAILED", "error_message": error} - logging.error(f"Error in request for {name} - {r.status_code}: {error}") - except Exception: - response = {"status": "FAILED"} - if resp is not None: - response["error"] = resp - return response + logging.error(f"Error in request for {name} - {r.status_code}: {error}") + if response_version == "v1": + return { + "status": "failed", + "error": error, + "elapsed_time": None, + **kwargs, + } + return PipelineResponse( + status=ResponseStatus.FAILED, + error={"error": error, "status": "ERROR"}, + elapsed_time=None, + **kwargs, + ) + except Exception as e: + if response_version == "v1": + return { + "status": "failed", + "error": str(e), + "elapsed_time": None, + **kwargs, + } + return PipelineResponse( + status=ResponseStatus.FAILED, + error={"error": str(e), "status": "ERROR"}, + elapsed_time=None, + **kwargs, + ) def update( self, diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index 55300467..063accad 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -271,37 +271,54 @@ def create_param( param.node = self.node return param - def __getitem__(self, code: str) -> Param: + def __getattr__(self, code: str) -> Param: + if code == "_params": + raise AttributeError("Attribute '_params' is not accessible") for param in self._params: if param.code == code: return param - raise KeyError(f"Parameter with code '{code}' not found.") + raise AttributeError(f"Attribute with code '{code}' not found.") + + def __getitem__(self, code: str) -> Param: + try: + return getattr(self, code) + except AttributeError: + raise KeyError(f"Parameter with code '{code}' not found.") def special_prompt_handling(self, code: str, value: str) -> None: """ This method will handle the special prompt handling for asset nodes having `text-generation` function type. """ + prompt_param = getattr(self, "prompt", None) + if prompt_param: + raise ValueError("Prompt param already exists") + from .nodes import AssetNode - if isinstance(self.node, AssetNode) and self.node.asset.function == "text-generation": - if code == "prompt": - matches = find_prompt_params(value) - for match in matches: - self.node.inputs.create_param(match, DataType.TEXT, is_required=True) + if not isinstance(self.node, AssetNode): + return - def set_param_value(self, code: str, value: str) -> None: - self.special_prompt_handling(code, value) - self[code].value = value + if not hasattr(self.node, "asset") or self.node.asset.function != "text-generation": + return + + matches = find_prompt_params(value) + for match in matches: + if match in self: + raise ValueError(f"Prompt param with code '{match}' already exists") + + self.node.inputs.create_param(match, DataType.TEXT, is_required=True) def __setitem__(self, code: str, value: str) -> None: - # set param value on set item to avoid setting it manually - self.set_param_value(code, value) + setattr(self, code, value) def __setattr__(self, name: str, value: any) -> None: - # set param value on attribute assignment to avoid setting it manually - if isinstance(value, str) and hasattr(self, name): - self.set_param_value(name, value) + if name == "prompt": + self.special_prompt_handling(name, value) + + param = getattr(self, name, None) + if param and isinstance(param, Param): + param.value = value else: super().__setattr__(name, value) diff --git a/aixplain/modules/pipeline/response.py b/aixplain/modules/pipeline/response.py new file mode 100644 index 00000000..deda345d --- /dev/null +++ b/aixplain/modules/pipeline/response.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Any, Optional, Dict, Text +from aixplain.enums import ResponseStatus + + +@dataclass +class PipelineResponse: + def __init__( + self, + status: ResponseStatus, + error: Optional[Dict[str, Any]] = None, + elapsed_time: Optional[float] = 0.0, + data: Optional[Text] = None, + url: Optional[Text] = "", + **kwargs, + ): + self.status = status + self.error = error + self.elapsed_time = elapsed_time + self.data = data + self.additional_fields = kwargs + self.url = url + + def __getattr__(self, key: str) -> Any: + if self.additional_fields and key in self.additional_fields: + return self.additional_fields[key] + + raise AttributeError() + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __repr__(self) -> str: + fields = [] + if self.status: + fields.append(f"status={self.status}") + if self.error: + fields.append(f"error={self.error}") + if self.elapsed_time is not None: + fields.append(f"elapsed_time={self.elapsed_time}") + if self.data: + fields.append(f"data={self.data}") + if self.additional_fields: + fields.extend([f"{k}={repr(v)}" for k, v in self.additional_fields.items()]) + return f"PipelineResponse({', '.join(fields)})" + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index edb737d4..825967e8 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -61,6 +61,8 @@ class TeamAgent(Model): use_mentalist_and_inspector (bool): Use Mentalist and Inspector tools. Defaults to True. """ + is_valid: bool + def __init__( self, id: Text, @@ -72,7 +74,8 @@ def __init__( supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, cost: Optional[Dict] = None, - use_mentalist_and_inspector: bool = True, + use_mentalist: bool = True, + use_inspector: bool = True, status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: @@ -95,13 +98,16 @@ def __init__( self.additional_info = additional_info self.agents = agents self.llm_id = llm_id - self.use_mentalist_and_inspector = use_mentalist_and_inspector + self.use_mentalist = use_mentalist + self.use_inspector = use_inspector + if isinstance(status, str): try: status = AssetStatus(status) except Exception: status = AssetStatus.DRAFT self.status = status + self.is_valid = True def run( self, @@ -156,12 +162,18 @@ def run( return response poll_url = response["url"] end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + response = self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time + ) return response except Exception as e: logging.error(f"Team Agent Run: Error in running for {name}: {e}") end = time.time() - return AgentResponse(status=ResponseStatus.FAILED, completed=False, error_message="No response from the service.") + return AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) def run_async( self, @@ -194,10 +206,19 @@ def run_async( """ from aixplain.factories.file_factory import FileFactory - assert data is not None or query is not None, "Either 'data' or 'query' must be provided." + if not self.is_valid: + raise Exception( + "Team Agent is not valid. Please validate the team agent before running." + ) + + assert ( + data is not None or query is not None + ), "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + assert ( + "query" in data and data["query"] is not None + ), "When providing a dictionary, 'query' must be provided." if session_id is None: session_id = data.pop("session_id", None) if history is None: @@ -211,7 +232,8 @@ def run_async( # process content inputs if content is not None: assert ( - isinstance(query, str) and FileFactory.check_storage_type(query) == StorageType.TEXT + isinstance(query, str) + and FileFactory.check_storage_type(query) == StorageType.TEXT ), "When providing 'content', query must be text." if isinstance(content, list): @@ -221,7 +243,9 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." + assert ( + "{{" + key + "}}" in query + ), f"Key '{key}' not found in query." value = FileFactory.to_link(value) query = query.replace("{{" + key + "}}", f"'{value}'") @@ -236,8 +260,16 @@ def run_async( "sessionId": session_id, "history": history, "executionParams": { - "maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, - "maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, + "maxTokens": ( + parameters["max_tokens"] + if "max_tokens" in parameters + else max_tokens + ), + "maxIterations": ( + parameters["max_iterations"] + if "max_iterations" in parameters + else max_iterations + ), "outputFormat": output_format.value, }, } @@ -245,7 +277,9 @@ def run_async( payload = json.dumps(payload) r = _request_with_retry("post", self.url, headers=headers, data=payload) - logging.info(f"Team Agent Run Async: Start service for {name} - {self.url} - {payload} - {headers}") + logging.info( + f"Team Agent Run Async: Start service for {name} - {self.url} - {payload} - {headers}" + ) resp = None try: @@ -266,15 +300,16 @@ def delete(self) -> None: """Delete Corpus service""" try: url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{self.id}") - headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + headers = { + "x-api-key": config.TEAM_API_KEY, + "Content-Type": "application/json", + } logging.debug(f"Start service for DELETE Team Agent - {url} - {headers}") r = _request_with_retry("delete", url, headers=headers) if r.status_code != 200: raise Exception() except Exception: - message = ( - f"Team Agent Deletion Error (HTTP {r.status_code}): Make sure the Team Agent exists and you are the owner." - ) + message = f"Team Agent Deletion Error (HTTP {r.status_code}): Make sure the Team Agent exists and you are the owner." logging.error(message) raise Exception(f"{message}") @@ -283,19 +318,21 @@ def to_dict(self) -> Dict: "id": self.id, "name": self.name, "agents": [ - {"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"} for idx, agent in enumerate(self.agents) + {"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"} + for idx, agent in enumerate(self.agents) ], "links": [], "description": self.description, "llmId": self.llm_id, "supervisorId": self.llm_id, - "plannerId": self.llm_id if self.use_mentalist_and_inspector else None, + "plannerId": self.llm_id if self.use_mentalist else None, + "inspectorId": self.llm_id if self.use_inspector else None, "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, "version": self.version, "status": self.status.value, } - def validate(self) -> None: + def _validate(self) -> None: """Validate the Team.""" from aixplain.factories.model_factory import ModelFactory @@ -306,12 +343,30 @@ def validate(self) -> None: try: llm = ModelFactory.get(self.llm_id) - assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." + assert ( + llm.function == Function.TEXT_GENERATION + ), "Large Language Model must be a text generation model." except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") for agent in self.agents: - agent.validate() + agent.validate(raise_exception=True) + + def validate(self, raise_exception: bool = False) -> bool: + try: + self._validate() + self.is_valid = True + except Exception as e: + self.is_valid = False + if raise_exception: + raise e + else: + logging.warning(f"Team Agent Validation Error: {e}") + logging.warning( + "You won't be able to run the Team Agent until the issues are handled manually." + ) + + return self.is_valid def update(self) -> None: """Update the Team Agent.""" @@ -322,25 +377,30 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " + "Please use save() instead.", DeprecationWarning, stacklevel=2, ) from aixplain.factories.team_agent_factory.utils import build_team_agent - self.validate() + self.validate(raise_exception=True) url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{self.id}") headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} payload = self.to_dict() - logging.debug(f"Start service for PUT Update Team Agent - {url} - {headers} - {json.dumps(payload)}") + logging.debug( + f"Start service for PUT Update Team Agent - {url} - {headers} - {json.dumps(payload)}" + ) resp = "No specified error." try: r = _request_with_retry("put", url, headers=headers, json=payload) resp = r.json() except Exception: - raise Exception("Team Agent Update Error: Please contact the administrators.") + raise Exception( + "Team Agent Update Error: Please contact the administrators." + ) if 200 <= r.status_code < 300: return build_team_agent(resp) @@ -354,7 +414,11 @@ def save(self) -> None: def deploy(self) -> None: """Deploy the Team Agent.""" - assert self.status == AssetStatus.DRAFT, "Team Agent Deployment Error: Team Agent must be in draft status." - assert self.status != AssetStatus.ONBOARDED, "Team Agent Deployment Error: Team Agent must be onboarded." + assert ( + self.status == AssetStatus.DRAFT + ), "Team Agent Deployment Error: Team Agent must be in draft status." + assert ( + self.status != AssetStatus.ONBOARDED + ), "Team Agent Deployment Error: Team Agent must be onboarded." self.status = AssetStatus.ONBOARDED self.update() diff --git a/aixplain/v2/agent.py b/aixplain/v2/agent.py index 9f5f91ae..af8c2d39 100644 --- a/aixplain/v2/agent.py +++ b/aixplain/v2/agent.py @@ -125,14 +125,50 @@ def create_custom_python_code_tool(cls, code: Union[str, Callable], description: def create_sql_tool( cls, description: str, - database: str, + source: str, + source_type: str, schema: Optional[str] = None, tables: Optional[List[str]] = None, enable_commit: bool = False, ) -> "SQLTool": - """Create a new SQL tool.""" + """Create a new SQL tool. + + Args: + description (str): description of the database tool + source (Union[str, Dict]): database source - can be a connection string or dictionary with connection details + source_type (str): type of source (sqlite, csv) + schema (Optional[str], optional): database schema description + tables (Optional[List[str]], optional): table names to work with (optional) + enable_commit (bool, optional): enable to modify the database (optional) + + Returns: + SQLTool: created SQLTool + + Examples: + # SQLite - Simple + sql_tool = Agent.create_sql_tool( + description="My SQLite Tool", + source="/path/to/database.sqlite", + source_type="sqlite", + tables=["users", "products"] + ) + + # CSV - Simple + sql_tool = Agent.create_sql_tool( + description="My CSV Tool", + source="/path/to/data.csv", + source_type="csv", + tables=["data"] + ) + + """ from aixplain.factories import AgentFactory return AgentFactory.create_sql_tool( - description=description, database=database, schema=schema, tables=tables, enable_commit=enable_commit + description=description, + source=source, + source_type=source_type, + schema=schema, + tables=tables, + enable_commit=enable_commit, ) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 1446af5a..cf018b8d 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -352,20 +352,23 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents import os + # Create test SQLite database with open("ftest.db", "w") as f: f.write("") tool = AgentFactory.create_sql_tool( - description="Execute an SQL query and return the result", database="ftest.db", enable_commit=True + description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", enable_commit=True ) assert tool is not None assert tool.description == "Execute an SQL query and return the result" + agent = AgentFactory.create( name="Teste", description="You are a test agent that search for employee information in a database", tools=[tool], ) assert agent is not None + response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") assert response is not None assert response["completed"] is True @@ -383,3 +386,73 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert "eve" in str(response["data"]["output"]).lower() os.remove("ftest.db") + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): + assert delete_agents_and_team_agents + + import pandas as pd + + # Create a more comprehensive test dataset + df = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "department": ["Sales", "IT", "Sales", "Marketing", "IT"], + "salary": [75000, 85000, 72000, 68000, 90000], + } + ) + df.to_csv("test.csv", index=False) + + # Create SQL tool from CSV + tool = AgentFactory.create_sql_tool( + description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] + ) + + # Verify tool setup + assert tool is not None + assert tool.description == "Execute SQL queries on employee data" + assert tool.database.endswith(".db") + assert tool.tables == ["employees"] + assert ( + tool.schema + == 'CREATE TABLE test (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 + ) + assert not tool.enable_commit # must be False by default + + # Create an agent with the SQL tool + agent = AgentFactory.create( + name="SQL Query Agent", + description="I am an agent that helps query employee information from a database.", + instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", + tools=[tool], + ) + assert agent is not None + + # Test 1: Basic SELECT query + response = agent.run("Who are all the employees in the Sales department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "alice" in response["data"]["output"].lower() + assert "charlie" in response["data"]["output"].lower() + + # Test 2: Aggregation query + response = agent.run("What is the average salary in each department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "sales" in response["data"]["output"].lower() + assert "it" in response["data"]["output"].lower() + assert "marketing" in response["data"]["output"].lower() + + # Test 3: Complex query with conditions + response = agent.run("Who is the highest paid employee in the IT department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in response["data"]["output"].lower() + + import os + + # Cleanup + os.remove("test.csv") + os.remove("test.db") diff --git a/tests/functional/finetune/data/finetune_test_cost_estimation.json b/tests/functional/finetune/data/finetune_test_cost_estimation.json index 44707255..39ffbd45 100644 --- a/tests/functional/finetune/data/finetune_test_cost_estimation.json +++ b/tests/functional/finetune/data/finetune_test_cost_estimation.json @@ -2,7 +2,7 @@ {"model_name": "Llama 2 7b", "model_id": "6543cb991f695e72028e9428", "dataset_name": "Test text generation dataset"}, {"model_name": "Llama 2 7B Chat", "model_id": "65519ee7bf42e6037ab109d8", "dataset_name": "Test text generation dataset"}, {"model_name": "Mistral 7b", "model_id": "6551a9e7bf42e6037ab109de", "dataset_name": "Test text generation dataset"}, - {"model_name": "Solar 10b", "model_id": "65b7baac1d5ea75105c14971", "dataset_name": "Test text generation dataset"}, + {"model_name": "Mistral 7B Instruct v0.3", "model_id": "6551a9e7bf42e6037ab109de", "dataset_name": "Test text generation dataset"}, {"model_name": "Falcon 7b", "model_id": "6551bff9bf42e6037ab109e1", "dataset_name": "Test text generation dataset"}, {"model_name": "Falcon 7b Instruct", "model_id": "65519d57bf42e6037ab109d5", "dataset_name": "Test text generation dataset"}, {"model_name": "MPT 7b", "model_id": "6551a72bbf42e6037ab109d9", "dataset_name": "Test text generation dataset"}, diff --git a/tests/functional/finetune/data/finetune_test_end2end.json b/tests/functional/finetune/data/finetune_test_end2end.json index 68499460..c8b1a645 100644 --- a/tests/functional/finetune/data/finetune_test_end2end.json +++ b/tests/functional/finetune/data/finetune_test_end2end.json @@ -6,21 +6,5 @@ "inference_data": "Hello!", "required_dev": true, "search_metadata": false - }, - { - "model_name": "aiR v2", - "model_id": "66eae6656eb56311f2595011", - "dataset_name": "Test search dataset", - "inference_data": "Hello!", - "required_dev": false, - "search_metadata": false - }, - { - "model_name": "vectara", - "model_id": "655e20f46eb563062a1aa301", - "dataset_name": "Test search dataset", - "inference_data": "Hello!", - "required_dev": false, - "search_metadata": false } ] diff --git a/tests/functional/finetune/data/finetune_test_list_data.json b/tests/functional/finetune/data/finetune_test_list_data.json index f8b25910..b5b13a57 100644 --- a/tests/functional/finetune/data/finetune_test_list_data.json +++ b/tests/functional/finetune/data/finetune_test_list_data.json @@ -1,8 +1,5 @@ [ { "function": "text-generation" - }, - { - "function": "search" } ] \ No newline at end of file diff --git a/tests/functional/model/data/test_input.txt b/tests/functional/model/data/test_input.txt new file mode 100644 index 00000000..7bb1dcb0 --- /dev/null +++ b/tests/functional/model/data/test_input.txt @@ -0,0 +1 @@ +Hello! Here is a robot emoji: 🤖 Response should contain this emoji. \ No newline at end of file diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index c794face..d3d0082f 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -5,6 +5,7 @@ from aixplain.factories import ModelFactory from aixplain.modules import LLM from datetime import datetime, timedelta, timezone +from pathlib import Path def pytest_generate_tests(metafunc): @@ -74,3 +75,22 @@ def test_index_model(): assert "aixplain" in response.data.lower() assert index_model.count() == 1 index_model.delete() + + +def test_llm_run_with_file(): + """Testing LLM with local file input containing emoji""" + + # Create test file path + test_file_path = Path(__file__).parent / "data" / "test_input.txt" + + # Get a text generation model + llm_model = ModelFactory.get("674a17f6098e7d5b18453da7") # Llama 3.1 Nemotron 70B Instruct + + assert isinstance(llm_model, LLM) + + # Run model with file path + response = llm_model.run(data=str(test_file_path)) + + # Verify response + assert response["status"] == "SUCCESS" + assert "🤖" in response["data"], "Robot emoji should be present in the response" diff --git a/tests/functional/pipelines/designer_test.py b/tests/functional/pipelines/designer_test.py index 0a955651..73a67996 100644 --- a/tests/functional/pipelines/designer_test.py +++ b/tests/functional/pipelines/designer_test.py @@ -1,6 +1,6 @@ import pytest -from aixplain.enums import DataType +from aixplain.enums import DataType, ResponseStatus from aixplain.factories import PipelineFactory, DatasetFactory from aixplain.modules.pipeline.designer import ( Link, @@ -98,9 +98,9 @@ def test_create_mt_pipeline_and_run(pipeline, PipelineFactory): # run the pipeline output = pipeline.run( "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.txt", - **{"batchmode": False, "version": "2.0"}, + **{"batchmode": False, "version": "3.0"}, ) - assert output["status"] == "SUCCESS" + assert output["status"] == ResponseStatus.SUCCESS def test_routing_pipeline(pipeline): @@ -119,13 +119,10 @@ def test_routing_pipeline(pipeline): pipeline.save() - output = pipeline.run("This is a sample text!") - - assert output["status"] == "SUCCESS" - assert output.get("data") is not None - assert len(output["data"]) > 0 - assert output["data"][0].get("segments") is not None - assert len(output["data"][0]["segments"]) > 0 + output = pipeline.run( + "This is a sample text!", **{"batchmode": False, "version": "3.0"} + ) + assert output["status"] == ResponseStatus.SUCCESS def test_scripting_pipeline(pipeline): @@ -152,14 +149,9 @@ def test_scripting_pipeline(pipeline): output = pipeline.run( "s3://aixplain-platform-assets/samples/en/CPAC1x2.wav", - version="2.0", + version="3.0", ) - - assert output["status"] == "SUCCESS" - assert output.get("data") is not None - assert len(output["data"]) > 0 - assert output["data"][0].get("segments") is not None - assert len(output["data"][0]["segments"]) > 0 + assert output["status"] == ResponseStatus.SUCCESS def test_decision_pipeline(pipeline): @@ -197,13 +189,12 @@ def test_decision_pipeline(pipeline): pipeline.save() - output = pipeline.run("I feel so bad today!") - - assert output["status"] == "SUCCESS" + output = pipeline.run( + "I feel so bad today!", + version="3.0", + ) + assert output["status"] == ResponseStatus.SUCCESS assert output.get("data") is not None - assert len(output["data"]) > 0 - assert output["data"][0].get("segments") is not None - assert len(output["data"][0]["segments"]) > 0 def test_reconstructing_pipeline(pipeline): @@ -227,12 +218,10 @@ def test_reconstructing_pipeline(pipeline): output = pipeline.run( "s3://aixplain-platform-assets/samples/en/CPAC1x2.wav", + version="3.0", ) - assert output["status"] == "SUCCESS" + assert output["status"] == ResponseStatus.SUCCESS assert output.get("data") is not None - assert len(output["data"]) > 0 - assert output["data"][0].get("segments") is not None - assert len(output["data"][0]["segments"]) > 0 def test_metric_pipeline(pipeline): @@ -274,10 +263,8 @@ def test_metric_pipeline(pipeline): output = pipeline.run( data={"TextInput": reference_id, "ReferenceInput": reference_id}, data_asset={"TextInput": data_asset_id, "ReferenceInput": data_asset_id}, + version="3.0", ) - assert output["status"] == "SUCCESS" + assert output["status"] == ResponseStatus.SUCCESS assert output.get("data") is not None - assert len(output["data"]) > 0 - assert output["data"][0].get("segments") is not None - assert len(output["data"][0]["segments"]) > 0 diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py index e5b8bcdf..32e3bdfe 100644 --- a/tests/functional/pipelines/run_test.py +++ b/tests/functional/pipelines/run_test.py @@ -18,8 +18,8 @@ import pytest import os -import requests from aixplain.factories import DatasetFactory, PipelineFactory +from aixplain.enums.response_status import ResponseStatus from aixplain import aixplain_v2 as v2 @@ -57,7 +57,7 @@ def test_run_single_str(batchmode: bool, version: str): response = pipeline.run( data="Translate this thing", batch_mode=batchmode, **{"version": version} ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -79,7 +79,7 @@ def test_run_single_local_file(batchmode: bool, version: str, PipelineFactory): response = pipeline.run(data=fname, batch_mode=batchmode, **{"version": version}) os.remove(fname) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -100,7 +100,7 @@ def test_run_with_url(batchmode: bool, version: str, PipelineFactory): batch_mode=batchmode, **{"version": version}, ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -125,7 +125,7 @@ def test_run_with_dataset(batchmode: bool, version: str, PipelineFactory): batch_mode=batchmode, **{"version": version}, ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -146,7 +146,7 @@ def test_run_multipipe_with_strings(batchmode: bool, version: str, PipelineFacto batch_mode=batchmode, **{"version": version}, ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -174,7 +174,7 @@ def test_run_multipipe_with_datasets(batchmode: bool, version: str, PipelineFact batch_mode=batchmode, **{"version": version}, ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize("version", ["2.0", "3.0"]) @@ -188,9 +188,7 @@ def test_run_segment_reconstruct(version: str, PipelineFactory): **{"version": version}, ) - assert response["status"] == "SUCCESS" - output = response["data"][0] - assert output["label"] == "Output 1" + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize("version", ["2.0", "3.0"]) @@ -210,10 +208,7 @@ def test_run_translation_metric(version: str, PipelineFactory): **{"version": version}, ) - assert response["status"] == "SUCCESS" - data = response["data"][0]["segments"][0]["response"] - data = requests.get(data).text - assert float(data) == 100.0 + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize("version", ["2.0", "3.0"]) @@ -230,10 +225,7 @@ def test_run_metric(version: str, PipelineFactory): **{"version": version}, ) - assert response["status"] == "SUCCESS" - assert len(response["data"]) == 2 - assert response["data"][0]["label"] in ["TranscriptOutput", "ScoreOutput"] - assert response["data"][1]["label"] in ["TranscriptOutput", "ScoreOutput"] + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -266,8 +258,7 @@ def test_run_router(input_data: str, output_data: str, version: str, PipelineFac pipeline = PipelineFactory.list(query="Router Test - DO NOT DELETE")["results"][0] response = pipeline.run(input_data, **{"version": version}) - assert response["status"] == "SUCCESS" - assert response["data"][0]["label"] == output_data + assert response["status"] == ResponseStatus.SUCCESS @pytest.mark.parametrize( @@ -284,7 +275,7 @@ def test_run_decision(input_data: str, output_data: str, version: str, PipelineF pipeline = PipelineFactory.list(query="Decision Test - DO NOT DELETE")["results"][0] response = pipeline.run(input_data, **{"version": version}) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS assert response["data"][0]["label"] == output_data @@ -299,7 +290,7 @@ def test_run_script(version: str, PipelineFactory): **{"version": version}, ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS data = response["data"][0]["segments"][0]["response"] assert data.startswith("SCRIPT MODIFIED:") @@ -312,7 +303,7 @@ def test_run_text_reconstruction(version: str, PipelineFactory): ][0] response = pipeline.run("Segment A\nSegment B\nSegment C", **{"version": version}) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS labels = [d["label"] for d in response["data"]] assert "Audio (Direct)" in labels assert "Audio (Text Reconstruction)" in labels @@ -335,7 +326,7 @@ def test_run_diarization(version: str, PipelineFactory): **{"version": version}, ) - assert response["status"] == "SUCCESS" + assert response["status"] == ResponseStatus.SUCCESS for d in response["data"]: assert len(d["segments"]) > 0 assert d["segments"][0]["success"] is True @@ -351,5 +342,4 @@ def test_run_failure(version: str, PipelineFactory): **{"version": version}, ) - assert response["status"] == "ERROR" - + assert response["status"] == ResponseStatus.FAILED diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 1d2785f2..c76dc8eb 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -166,14 +166,44 @@ def test_draft_team_agent_update(run_input_map, TeamAgentFactory): @pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_fail_non_existent_llm(TeamAgentFactory): +def test_fail_non_existent_llm(run_input_map, TeamAgentFactory): + for team in TeamAgentFactory.list()["results"]: + team.delete() + for agent in AgentFactory.list()["results"]: + agent.delete() + + agents = [] + for agent in run_input_map["agents"]: + tools = [] + if "model_tools" in agent: + for tool in agent["model_tools"]: + tool_ = copy(tool) + for supplier in Supplier: + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier.value["code"].lower(), + supplier.value["name"].lower(), + ]: + tool_["supplier"] = supplier + break + tools.append(AgentFactory.create_model_tool(**tool_)) + if "pipeline_tools" in agent: + for tool in agent["pipeline_tools"]: + tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) + + agent = AgentFactory.create( + name=agent["agent_name"], + description=agent["agent_name"], + instructions=agent["agent_name"], + llm_id=agent["llm_id"], + tools=tools, + ) + agents.append(agent) with pytest.raises(Exception) as exc_info: - AgentFactory.create( - name="Test Agent", + TeamAgentFactory.create( + name="Non Existent LLM", description="", - instructions="", llm_id="non_existent_llm", - tools=[AgentFactory.create_model_tool(function=Function.TRANSLATION)], + agents=agents, ) assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found." diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index 2685be77..8afda296 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -108,9 +108,9 @@ def test_invalid_llm_id(): def test_invalid_agent_name(): with pytest.raises(Exception) as exc_info: AgentFactory.create(name="[Test]", description="", instructions="", tools=[], llm_id="6646261c6eb563165658bbb1") - assert ( - str(exc_info.value) - == "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." + assert str(exc_info.value) == ( + "Agent Creation Error: Agent name contains invalid characters. " + "Only alphanumeric characters, spaces, hyphens, and brackets are allowed." ) @@ -160,13 +160,13 @@ def test_create_agent(mock_model_factory_get): { "type": "utility", "utility": "custom_python_code", - "description": "", + "utilityCode": "def main(query: str) -> str:\n return 'Hello, how are you?'", + "description": "Test Tool", }, { "type": "utility", "utility": "custom_python_code", - "utilityCode": "def main(query: str) -> str:\n return 'Hello, how are you?'", - "description": "Test Tool", + "description": "", }, ], } @@ -209,10 +209,9 @@ def test_create_agent(mock_model_factory_get): assert agent.tools[0].description == ref_response["assets"][0]["description"] assert isinstance(agent.tools[0], ModelTool) assert agent.tools[1].description == ref_response["assets"][1]["description"] - assert isinstance(agent.tools[1], PythonInterpreterTool) + assert isinstance(agent.tools[1], CustomPythonCodeTool) assert agent.tools[2].description == ref_response["assets"][2]["description"] - assert agent.tools[2].code == ref_response["assets"][2]["utilityCode"] - assert isinstance(agent.tools[2], CustomPythonCodeTool) + assert isinstance(agent.tools[2], PythonInterpreterTool) assert agent.status == AssetStatus.DRAFT @@ -300,7 +299,7 @@ def test_update_success(mock_model_factory_get): # Capture warnings with pytest.warns( DeprecationWarning, - match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead.", + match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead.", # noqa: W605 ): agent.update() @@ -412,9 +411,10 @@ def test_run_variable_error(): agent = Agent("123", "Test Agent", "Translate the input data into {target_language}", "Test Agent Role") with pytest.raises(Exception) as exc_info: agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) - assert ( - str(exc_info.value) - == "Variable 'target_language' not found in data or parameters. This variable is required by the agent according to its description ('Translate the input data into {target_language}')." + assert str(exc_info.value) == ( + "Variable 'target_language' not found in data or parameters. " + "This variable is required by the agent according to its description " + "('Translate the input data into {target_language}')." ) @@ -731,14 +731,13 @@ def test_create_agent_with_model_instance(mock_model_factory_get, mock_validate) # Verify the tool was converted correctly tool = agent.tools[0] - assert isinstance(tool, ModelTool) - assert tool.model == "model123" - assert tool.function == Function.TEXT_GENERATION - assert tool.supplier == Supplier.AIXPLAIN - assert isinstance(tool.model_object, Model) - assert isinstance(tool.model_object.model_params, ModelParameters) - assert tool.model_object.model_params.parameters["temperature"].required - assert not tool.model_object.model_params.parameters["max_tokens"].required + assert isinstance(tool, Model) + assert tool.name == model_tool.name + assert tool.function == model_tool.function + assert tool.supplier == model_tool.supplier + assert isinstance(tool.model_params, ModelParameters) + assert tool.model_params.parameters["temperature"].required + assert not tool.model_params.parameters["max_tokens"].required @patch("aixplain.modules.agent.tool.model_tool.ModelTool.validate", autospec=True) @@ -862,16 +861,15 @@ def validate_side_effect(self, *args, **kwargs): assert agent.description == ref_response["description"] assert len(agent.tools) == 2 - # Verify the first tool (Model instance converted to ModelTool) + # Verify the first tool (Model) tool1 = agent.tools[0] - assert isinstance(tool1, ModelTool) - assert tool1.model == "model123" - assert tool1.function == Function.TEXT_GENERATION - assert tool1.supplier == Supplier.AIXPLAIN - assert isinstance(tool1.model_object, Model) - assert isinstance(tool1.model_object.model_params, ModelParameters) - assert tool1.model_object.model_params.parameters["temperature"].required - assert not tool1.model_object.model_params.parameters["max_tokens"].required + assert isinstance(tool1, Model) + assert tool1.name == model_tool.name + assert tool1.function == model_tool.function + assert tool1.supplier == model_tool.supplier + assert isinstance(tool1.model_params, ModelParameters) + assert tool1.model_params.parameters["temperature"].required + assert not tool1.model_params.parameters["max_tokens"].required # Verify the second tool (regular ModelTool) tool2 = agent.tools[1] @@ -913,3 +911,53 @@ def test_create_model_tool_with_text_supplier(supplier_input, expected_supplier, assert tool.supplier.name == expected_supplier assert tool.function == Function.TEXT_GENERATION assert tool.description == "Test Tool" + + +def test_agent_response_repr(): + from aixplain.enums import ResponseStatus + from aixplain.modules.agent.agent_response import AgentResponse, AgentResponseData + + # Test case 1: Basic representation + response = AgentResponse(status=ResponseStatus.SUCCESS, data=AgentResponseData(input="test input"), completed=True) + repr_str = repr(response) + + # Verify the representation starts with "AgentResponse(" + assert repr_str.startswith("AgentResponse(") + assert repr_str.endswith(")") + + # Verify key fields are present and correct + assert "status=SUCCESS" in repr_str + assert "completed=True" in repr_str + + # Test case 2: Complex representation with all fields + response = AgentResponse( + status=ResponseStatus.SUCCESS, + data=AgentResponseData( + input="test input", + output="test output", + session_id="test_session", + intermediate_steps=["step1", "step2"], + execution_stats={"time": 1.0}, + ), + details={"test": "details"}, + completed=True, + error_message="no error", + used_credits=0.5, + run_time=1.0, + usage={"tokens": 100}, + url="http://test.url", + ) + repr_str = repr(response) + + # Verify all fields are present and formatted correctly + assert "status=SUCCESS" in repr_str + assert "completed=True" in repr_str + assert "error_message='no error'" in repr_str + assert "used_credits=0.5" in repr_str + assert "run_time=1.0" in repr_str + assert "url='http://test.url'" in repr_str + assert "details={'test': 'details'}" in repr_str + assert "usage={'tokens': 100}" in repr_str + + # Most importantly, verify that 'status' is complete (not 'tatus') + assert "status=" in repr_str # Should find complete field name diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index 84770fc5..bb849d8f 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -115,6 +115,7 @@ def test_validate(mock_model, mock_model_factory, model_exists): tool = ModelTool() tool.model = "test_model_id" tool.api_key = None + tool.model_object = None validated_model = tool.validate() assert validated_model == mock_model else: @@ -123,6 +124,7 @@ def test_validate(mock_model, mock_model_factory, model_exists): tool = ModelTool() tool.model = "nonexistent_model" tool.api_key = None + tool.model_object = None with pytest.raises(Exception, match="Model Tool Unavailable"): tool.validate() diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index 074cfcbe..f6b6fa87 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -1,25 +1,312 @@ +import os +import pytest +import pandas as pd from aixplain.factories import AgentFactory -from aixplain.modules.agent.tool.sql_tool import SQLTool +from aixplain.modules.agent.tool.sql_tool import ( + SQLTool, + create_database_from_csv, + get_table_schema, + SQLToolError, + CSVError, + DatabaseError, + clean_column_name, +) -def test_create_sql_tool(mocker): - tool = AgentFactory.create_sql_tool(description="Test", database="test.db", schema="test", tables=["test", "test2"]) +def test_clean_column_name(): + # Test basic cleaning + assert clean_column_name("test name") == "test_name" + assert clean_column_name("test(name)") == "test_name" + assert clean_column_name("test/name") == "test_name" + assert clean_column_name("test__name") == "test_name" + assert clean_column_name(" test name ") == "test_name" + + # Test Case-insensitive + assert clean_column_name("Test Name") == "test_name" + assert clean_column_name("TEST NAME") == "test_name" + assert clean_column_name("TEST-NAME") == "test_name" + + # Test number prefix + assert clean_column_name("1test") == "col_1test" + + # Test special characters + assert clean_column_name("test@#$%^&*()name") == "test_name" + assert clean_column_name("test!!!name") == "test_name" + + # Test multiple underscores + assert clean_column_name("test___name") == "test_name" + + # Test leading/trailing special chars + assert clean_column_name("_test_name_") == "test_name" + assert clean_column_name("___test___name___") == "test_name" + + +def test_create_sql_tool(mocker, tmp_path): + # Create a test database file + # Create a test database file + db_path = os.path.join(tmp_path, "test.db") + import sqlite3 + + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER, name TEXT)") + conn.close() + + # Test SQLite source type + tool = AgentFactory.create_sql_tool( + description="Test", source=db_path, source_type="sqlite", schema="test", tables=["test", "test2"] + ) assert isinstance(tool, SQLTool) assert tool.description == "Test" - assert tool.database == "test.db" + assert tool.database == db_path assert tool.schema == "test" assert tool.tables == ["test", "test2"] + csv_path = os.path.join(tmp_path, "test.csv") + df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"]}) + df.to_csv(csv_path, index=False) + # Test CSV source type + csv_tool = AgentFactory.create_sql_tool(description="Test CSV", source=csv_path, source_type="csv", tables=["data"]) + assert isinstance(csv_tool, SQLTool) + assert csv_tool.description == "Test CSV" + assert csv_tool.database.endswith(".db") + + # Test to_dict() method tool_dict = tool.to_dict() assert tool_dict["description"] == "Test" assert tool_dict["parameters"] == [ - {"name": "database", "value": "test.db"}, + {"name": "database", "value": db_path}, {"name": "schema", "value": "test"}, {"name": "tables", "value": "test,test2"}, {"name": "enable_commit", "value": False}, ] + # Test validation and file upload mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://test.db") mocker.patch("os.path.exists", return_value=True) + mocker.patch("aixplain.modules.agent.tool.sql_tool.get_table_schema", return_value="CREATE TABLE test (id INTEGER)") + tool.validate() + assert tool.database == "s3://test.db" + + +def test_create_database_from_csv(tmp_path): + # Create a temporary CSV file + csv_path = os.path.join(tmp_path, "test.csv") + df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"], "value": [1.1, 2.2, 3.3]}) + df.to_csv(csv_path, index=False) + + # Create database from CSV + db_path = os.path.join(tmp_path, "test.db") + try: + schema = create_database_from_csv(csv_path, db_path) + + # Verify results + assert "CREATE TABLE test" in schema + assert '"id" INTEGER' in schema + assert '"name" TEXT' in schema + assert '"value" REAL' in schema + assert os.path.exists(db_path) + + # Test get_table_schema + retrieved_schema = get_table_schema(db_path) + assert retrieved_schema == schema + finally: + # Clean up the database file + if os.path.exists(db_path): + os.remove(db_path) + + +def test_create_database_from_csv_errors(tmp_path): + # Test non-existent CSV file + with pytest.raises(CSVError, match="CSV file .* does not exist"): + create_database_from_csv("nonexistent.csv", "test.db") + + # Test invalid file extension + invalid_ext = os.path.join(tmp_path, "test.txt") + open(invalid_ext, "w").close() + with pytest.raises(CSVError, match="File .* is not a CSV file"): + create_database_from_csv(invalid_ext, "test.db") + + # Test empty CSV file + empty_csv = os.path.join(tmp_path, "empty.csv") + open(empty_csv, "w").close() + with pytest.raises(CSVError, match="CSV file .* is empty"): + create_database_from_csv(empty_csv, "test.db") + + # Test empty CSV file + dup_cols_empty_csv = os.path.join(tmp_path, "dup_cols_empty.csv") + with open(dup_cols_empty_csv, "w") as f: + f.write("id,id\n") # Only header with duplicate columns, no data + with pytest.raises(CSVError, match="CSV file .* is empty"): + create_database_from_csv(dup_cols_empty_csv, "test.db") + + +def test_get_table_schema_errors(tmp_path): + # Test non-existent database + with pytest.raises(DatabaseError, match="Database file .* does not exist"): + get_table_schema("nonexistent.db") + + +def test_sql_tool_validation_errors(tmp_path): + # Create a test database file + db_path = os.path.join(tmp_path, "test.db") + # creat a proper sqlite database + import sqlite3 + + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER, name TEXT)") + conn.close() + + # Test missing description + with pytest.raises(SQLToolError, match="Description is required"): + tool = AgentFactory.create_sql_tool(description="", source=db_path, source_type="sqlite") + tool.validate() + + # Test missing source + with pytest.raises(SQLToolError, match="Source must be provided"): + tool = AgentFactory.create_sql_tool(description="Test", source="", source_type="sqlite") + tool.validate() + + # Test missing source_type + with pytest.raises(TypeError, match="missing 1 required positional argument: 'source_type'"): + tool = AgentFactory.create_sql_tool(description="Test", source=db_path) + tool.validate() + + # Test invalid source type + with pytest.raises(SQLToolError, match="Invalid source type"): + AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="invalid") + + # Test non-existent SQLite database + with pytest.raises(SQLToolError, match="Database .* does not exist"): + tool = AgentFactory.create_sql_tool(description="Test", source="nonexistent.db", source_type="sqlite") + tool.validate() + + # Test non-existent CSV file + with pytest.raises(SQLToolError, match="CSV file .* does not exist"): + tool = AgentFactory.create_sql_tool(description="Test", source="nonexistent.csv", source_type="csv") + tool.validate() + + # Test PostgreSQL (not supported) + with pytest.raises(SQLToolError, match="PostgreSQL is not supported yet"): + tool = AgentFactory.create_sql_tool( + description="Test", + source="postgresql://user:pass@localhost/mydb", + source_type="postgresql", + schema="public", + tables=["users"], + ) + tool.validate() + + +def test_create_sql_tool_with_schema_inference(tmp_path, mocker): + # Create a test database file + db_path = os.path.join(tmp_path, "test.db") + # creat a proper sqlite database + import sqlite3 + + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER, name TEXT)") + conn.close() + + # Create tool without schema and tables + tool = AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="sqlite") + + # Mock schema inference + schema = "CREATE TABLE test (id INTEGER, name TEXT)" + mocker.patch("os.path.exists", return_value=True) + mocker.patch("aixplain.modules.agent.tool.sql_tool.get_table_schema", return_value=schema) + mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://test.db") + + # Validate and check schema/tables inference tool.validate() + assert tool.schema == schema + assert tool.tables == ["test"] assert tool.database == "s3://test.db" + + +def test_create_sql_tool_from_csv_with_warnings(tmp_path, mocker): + # Create a CSV with column names that need cleaning + csv_path = os.path.join(tmp_path, "test with spaces.csv") + df = pd.DataFrame( + { + "1id": [1, 2], # Should be prefixed with col_ + "test name": ["test1", "test2"], # Should replace space with underscore + "value(%)": [1.1, 2.2], # Should remove special characters + } + ) + df.to_csv(csv_path, index=False) + + # Create tool and check for warnings + with pytest.warns(UserWarning) as record: + tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv") + + # Verify warnings about column name changes + warning_messages = [str(w.message) for w in record] + column_changes_warning = next( + (msg for msg in warning_messages if "Column names were cleaned for SQLite compatibility" in msg), None + ) + assert column_changes_warning is not None + assert "'1id' to 'col_1id'" in column_changes_warning + assert "'test name' to 'test_name'" in column_changes_warning + assert "'value(%)' to 'value'" in column_changes_warning + + try: + # Mock file upload for validation + mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://test.db") + + # Validate and verify schema + tool.validate() + assert "col_1id" in tool.schema + assert "test_name" in tool.schema + assert "value" in tool.schema + assert tool.tables == ["test_with_spaces"] + finally: + # Clean up the database file + if os.path.exists(tool.database): + os.remove(tool.database) + + +def test_create_sql_tool_from_csv(tmp_path): + # Create a temporary CSV file + csv_path = os.path.join(tmp_path, "test.csv") + df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"], "value": [1.1, 2.2, 3.3]}) + df.to_csv(csv_path, index=False) + + # Test successful creation + tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv", tables=["test"]) + assert isinstance(tool, SQLTool) + assert tool.description == "Test" + assert tool.database.endswith(".db") + assert os.path.exists(tool.database) + + # Test schema and table inference during validation + try: + tool.validate() + assert "CREATE TABLE test" in tool.schema + assert '"id" INTEGER' in tool.schema + assert '"name" TEXT' in tool.schema + assert '"value" REAL' in tool.schema + assert tool.tables == ["test"] + finally: + # Clean up the database file + if os.path.exists(tool.database): + os.remove(tool.database) + + +def test_sql_tool_schema_inference(tmp_path): + # Create a temporary CSV file + csv_path = os.path.join(tmp_path, "test.csv") + df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"]}) + df.to_csv(csv_path, index=False) + + # Create tool without schema and tables + tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv") + + try: + tool.validate() + assert tool.schema is not None + assert "CREATE TABLE test" in tool.schema + assert tool.tables == ["test"] + finally: + # Clean up the database file + if os.path.exists(tool.database): + os.remove(tool.database) diff --git a/tests/unit/designer_unit_test.py b/tests/unit/designer_unit_test.py index c8a21260..3b4d93bc 100644 --- a/tests/unit/designer_unit_test.py +++ b/tests/unit/designer_unit_test.py @@ -593,10 +593,18 @@ def test_param_proxy_set_param_value(): param_proxy = ParamProxy(Mock()) param_proxy._params = [prompt_param] with patch.object(param_proxy, "special_prompt_handling") as mock_special_prompt_handling: - param_proxy.set_param_value("prompt", "hello {{foo}}") + param_proxy.prompt = "hello {{foo}}" mock_special_prompt_handling.assert_called_once_with("prompt", "hello {{foo}}") assert prompt_param.value == "hello {{foo}}" + # Use a non string value + param_proxy.prompt = 123 + assert prompt_param.value == 123 + + # Now change it to another non string value + param_proxy.prompt = 456 + assert prompt_param.value == 456 + def test_param_proxy_special_prompt_handling(): from aixplain.modules.pipeline.designer.nodes import AssetNode diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index 913fe295..b49e9da2 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -25,6 +25,8 @@ from aixplain.factories import PipelineFactory from aixplain.modules import Pipeline from urllib.parse import urljoin +from aixplain.enums import ResponseStatus +from aixplain.modules.pipeline.response import PipelineResponse def test_create_pipeline(): @@ -33,8 +35,12 @@ def test_create_pipeline(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"id": "12345"} mock.post(url, headers=headers, json=ref_response) - ref_pipeline = Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - hyp_pipeline = PipelineFactory.create(pipeline={"nodes": []}, name="Pipeline Test") + ref_pipeline = Pipeline( + id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY + ) + hyp_pipeline = PipelineFactory.create( + pipeline={"nodes": []}, name="Pipeline Test" + ) assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name @@ -42,15 +48,30 @@ def test_create_pipeline(): @pytest.mark.parametrize( "status_code,error_message", [ - (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475, "Billing-related error: Please ensure you have enough credits to run this pipeline. "), + ( + 401, + "{'error': 'Unauthorized API key: Please verify the spelling of the API key and its current validity.', 'status': 'ERROR'}", + ), + ( + 465, + "{'error': 'Subscription-related error: Please ensure that your subscription is active and has not expired.', 'status': 'ERROR'}", + ), + ( + 475, + "{'error': 'Billing-related error: Please ensure you have enough credits to run this pipeline. ', 'status': 'ERROR'}", + ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.", + "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.', 'status': 'ERROR'}", + ), + ( + 495, + "{'error': 'Validation-related error: Please ensure all required fields are provided and correctly formatted.', 'status': 'ERROR'}", + ), + ( + 501, + "{'error': 'Status 501: Unspecified error: An unspecified error occurred while processing your request.', 'status': 'ERROR'}", ), - (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), - (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): @@ -60,10 +81,15 @@ def test_run_async_errors(status_code, error_message): with requests_mock.Mocker() as mock: mock.post(execute_url, status_code=status_code) - test_pipeline = Pipeline(id=pipeline_id, api_key=config.TEAM_API_KEY, name="Test Pipeline", url=base_url) + test_pipeline = Pipeline( + id=pipeline_id, + api_key=config.TEAM_API_KEY, + name="Test Pipeline", + url=base_url, + ) response = test_pipeline.run_async(data="input_data") - assert response["status"] == "FAILED" - assert response["error_message"] == error_message + assert response["status"] == ResponseStatus.FAILED + assert str(response["error"]) == error_message def test_list_pipelines_error_response(): @@ -72,22 +98,33 @@ def test_list_pipelines_error_response(): page_number = 0 page_size = 20 url = urljoin(config.BACKEND_URL, "sdk/pipelines/paginate") - headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.AIXPLAIN_API_KEY}", + "Content-Type": "application/json", + } error_response = {"statusCode": 400, "message": "Bad Request"} mock.post(url, headers=headers, json=error_response, status_code=400) with pytest.raises(Exception) as excinfo: - PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) + PipelineFactory.list( + query=query, page_number=page_number, page_size=page_size + ) - assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) + assert ( + "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" + in str(excinfo.value) + ) def test_get_pipeline_error_response(): with requests_mock.Mocker() as mock: pipeline_id = "test-pipeline-id" url = urljoin(config.BACKEND_URL, f"sdk/pipelines/{pipeline_id}") - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } error_response = {"statusCode": 404, "message": "Pipeline not found"} mock.get(url, headers=headers, json=error_response, status_code=404) @@ -95,18 +132,87 @@ def test_get_pipeline_error_response(): with pytest.raises(Exception) as excinfo: PipelineFactory.get(pipeline_id=pipeline_id) - assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value) + assert ( + "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" + in str(excinfo.value) + ) + + +@pytest.fixture +def mock_pipeline(): + return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) + + +def test_run_async_success(mock_pipeline): + with requests_mock.Mocker() as mock: + execute_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" + ) + success_response = PipelineResponse( + status=ResponseStatus.SUCCESS, url=execute_url + ) + mock.post(execute_url, json=success_response.__dict__, status_code=200) + + response = mock_pipeline.run_async(data="input_data") + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_run_sync_success(mock_pipeline): + with requests_mock.Mocker() as mock: + poll_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" + ) + execute_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" + ) + success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=poll_url) + poll_response = PipelineResponse( + status=ResponseStatus.SUCCESS, data={"output": "poll_result"} + ) + mock.post(execute_url, json=success_response.__dict__, status_code=200) + mock.get(poll_url, json=poll_response.__dict__, status_code=200) + response = mock_pipeline.run(data="input_data") + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_poll_success(mock_pipeline): + with requests_mock.Mocker() as mock: + poll_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" + ) + poll_response = PipelineResponse( + status=ResponseStatus.SUCCESS, data={"output": "poll_result"} + ) + mock.get(poll_url, json=poll_response.__dict__, status_code=200) + + response = mock_pipeline.poll(poll_url=poll_url) + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + assert response.data["output"] == "poll_result" def test_deploy_pipeline(): with requests_mock.Mocker() as mock: pipeline_id = "test-pipeline-id" url = urljoin(config.BACKEND_URL, f"sdk/pipelines/{pipeline_id}") - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } mock.put(url, headers=headers, json={"status": "SUCCESS", "id": pipeline_id}) - pipeline = Pipeline(id=pipeline_id, api_key=config.TEAM_API_KEY, name="Test Pipeline", url=config.BACKEND_URL) + pipeline = Pipeline( + id=pipeline_id, + api_key=config.TEAM_API_KEY, + name="Test Pipeline", + url=config.BACKEND_URL, + ) pipeline.deploy() assert pipeline.id == pipeline_id diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent_test.py index 5f54e1c3..a456b980 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent_test.py @@ -89,7 +89,8 @@ def test_to_dict(): ], description="Test Team Agent Description", llm_id="6646261c6eb563165658bbb1", - use_mentalist_and_inspector=False, + use_mentalist=False, + use_inspector=False, ) team_agent_dict = team_agent.to_dict() @@ -99,6 +100,7 @@ def test_to_dict(): assert team_agent_dict["llmId"] == "6646261c6eb563165658bbb1" assert team_agent_dict["supervisorId"] == "6646261c6eb563165658bbb1" assert team_agent_dict["plannerId"] is None + assert team_agent_dict["inspectorId"] is None assert len(team_agent_dict["agents"]) == 1 assert team_agent_dict["agents"][0]["assetId"] == "" assert team_agent_dict["agents"][0]["number"] == 0 @@ -182,6 +184,7 @@ def test_create_team_agent(mock_model_factory_get): "agents": [{"assetId": "123", "type": "AGENT", "number": 0, "label": "AGENT"}], "links": [], "plannerId": "6646261c6eb563165658bbb1", + "inspectorId": "6646261c6eb563165658bbb1", "supervisorId": "6646261c6eb563165658bbb1", "createdAt": "2024-10-28T19:30:25.344Z", "updatedAt": "2024-10-28T19:30:25.344Z", @@ -191,7 +194,8 @@ def test_create_team_agent(mock_model_factory_get): team_agent = TeamAgentFactory.create( name="TEST Multi agent(-)", description="TEST Multi agent", - use_mentalist_and_inspector=True, + use_mentalist=True, + use_inspector=True, llm_id="6646261c6eb563165658bbb1", agents=[agent], ) @@ -199,7 +203,8 @@ def test_create_team_agent(mock_model_factory_get): assert team_agent.name == team_ref_response["name"] assert team_agent.description == team_ref_response["description"] assert team_agent.llm_id == team_ref_response["llmId"] - assert team_agent.use_mentalist_and_inspector is True + assert team_agent.use_mentalist is True + assert team_agent.use_inspector is True assert team_agent.status == AssetStatus.DRAFT assert len(team_agent.agents) == 1 assert team_agent.agents[0].id == team_ref_response["agents"][0]["assetId"] @@ -216,6 +221,7 @@ def test_create_team_agent(mock_model_factory_get): "agents": [{"assetId": "123", "type": "AGENT", "number": 0, "label": "AGENT"}], "links": [], "plannerId": "6646261c6eb563165658bbb1", + "inspectorId": "6646261c6eb563165658bbb1", "supervisorId": "6646261c6eb563165658bbb1", "createdAt": "2024-10-28T19:30:25.344Z", "updatedAt": "2024-10-28T19:30:25.344Z", @@ -270,6 +276,7 @@ def get_mock(agent_id): "name": "Test Team Agent(-)", "description": "Test Team Agent Description", "plannerId": "6646261c6eb563165658bbb1", + "inspectorId": "6646261c6eb563165658bbb1", "llmId": "6646261c6eb563165658bbb1", "agents": [ {"assetId": "agent1"},