diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c7ad1c08..2a567d6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,4 @@ repos: - - repo: local - hooks: - - id: pytest-check - name: pytest-check - entry: coverage run --source=. -m pytest tests/unit - language: python - pass_filenames: false - types: [python] - always_run: true - - - repo: https://github.com/psf/black - rev: 25.1.0 - hooks: - - id: black - language_version: python3 - args: # arguments to configure black - - --line-length=128 - - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 # Use the latest version hooks: @@ -25,16 +7,19 @@ repos: - id: check-merge-conflict - id: check-added-large-files - - repo: https://github.com/pycqa/flake8 - rev: 7.2.0 - hooks: - - id: flake8 - args: # arguments to configure flake8 - - --ignore=E402,E501,E203,W503 - - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.12.12 hooks: - id: ruff args: [--fix] - id: ruff-format + + - repo: local + hooks: + - id: pytest-check + name: pytest-check + entry: coverage run --source=. -m pytest tests/unit + language: python + pass_filenames: false + types: [python] + always_run: true diff --git a/aixplain/exceptions/__init__.py b/aixplain/exceptions/__init__.py index 1d430d83..56e0f2c8 100644 --- a/aixplain/exceptions/__init__.py +++ b/aixplain/exceptions/__init__.py @@ -1,5 +1,4 @@ -""" -Error message registry for aiXplain SDK. +"""Error message registry for aiXplain SDK. This module maintains a centralized registry of error messages used throughout the aiXplain ecosystem. It allows developers to look up existing error messages and reuse them instead of creating new ones. @@ -9,6 +8,7 @@ AixplainBaseException, AuthenticationError, ValidationError, + AlreadyDeployedError, ResourceError, BillingError, SupplierError, @@ -33,12 +33,11 @@ def get_error_from_status_code(status_code: int, error_details: str = None) -> AixplainBaseException: - """ - Map HTTP status codes to appropriate exception types. + """Map HTTP status codes to appropriate exception types. Args: status_code (int): The HTTP status code to map. - default_message (str, optional): The default message to use if no specific message is available. + error_details (str, optional): Additional error details to include in the message. Returns: AixplainBaseException: An exception of the appropriate type. @@ -126,5 +125,6 @@ def get_error_from_status_code(status_code: int, error_details: str = None) -> A # Catch-all for other client/server errors category = "Client" if 400 <= status_code < 500 else "Server" return InternalError( - message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(), status_code=status_code + message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(), + status_code=status_code, ) diff --git a/aixplain/exceptions/types.py b/aixplain/exceptions/types.py index f38f1463..a565eb91 100644 --- a/aixplain/exceptions/types.py +++ b/aixplain/exceptions/types.py @@ -1,3 +1,5 @@ +"""Exception types and error handling for the aiXplain SDK.""" + from enum import Enum from typing import Optional, Dict, Any @@ -121,6 +123,17 @@ def __init__( retry_recommended: bool = False, error_code: Optional[ErrorCode] = None, ): + """Initialize the base exception with structured error information. + + Args: + message: Error message describing the issue. + category: Category of the error (default: UNKNOWN). + severity: Severity level of the error (default: ERROR). + status_code: HTTP status code if applicable. + details: Additional error context and details. + retry_recommended: Whether retrying the operation might succeed. + error_code: Standardized error code for the exception. + """ self.message = message self.category = category self.severity = severity @@ -163,6 +176,12 @@ class AuthenticationError(AixplainBaseException): """Raised when authentication fails.""" def __init__(self, message: str, **kwargs): + """Initialize authentication error. + + Args: + message: Error message describing the authentication issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.AUTHENTICATION, @@ -177,6 +196,12 @@ class ValidationError(AixplainBaseException): """Raised when input validation fails.""" def __init__(self, message: str, **kwargs): + """Initialize validation error. + + Args: + message: Error message describing the validation issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.VALIDATION, @@ -187,10 +212,34 @@ def __init__(self, message: str, **kwargs): ) +class AlreadyDeployedError(AixplainBaseException): + """Raised when attempting to deploy an asset that is already deployed.""" + + def __init__(self, message: str, **kwargs): + """Initialize already deployed error. + + Args: + message: Error message describing the deployment state conflict. + **kwargs: Additional keyword arguments passed to parent class. + """ + super().__init__( + message=message, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_VAL_ERROR, + **kwargs, + ) + + class ResourceError(AixplainBaseException): """Raised when a resource is unavailable.""" def __init__(self, message: str, **kwargs): + """Initialize resource error. + + Args: + message: Error message describing the resource issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.RESOURCE, @@ -205,6 +254,12 @@ class BillingError(AixplainBaseException): """Raised when there are billing issues.""" def __init__(self, message: str, **kwargs): + """Initialize billing error. + + Args: + message: Error message describing the billing issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.BILLING, @@ -219,6 +274,12 @@ class SupplierError(AixplainBaseException): """Raised when there are issues with external suppliers.""" def __init__(self, message: str, **kwargs): + """Initialize supplier error. + + Args: + message: Error message describing the supplier issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.SUPPLIER, @@ -233,6 +294,12 @@ class NetworkError(AixplainBaseException): """Raised when there are network connectivity issues.""" def __init__(self, message: str, **kwargs): + """Initialize network error. + + Args: + message: Error message describing the network issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.NETWORK, @@ -247,6 +314,12 @@ class ServiceError(AixplainBaseException): """Raised when a service is unavailable.""" def __init__(self, message: str, **kwargs): + """Initialize service error. + + Args: + message: Error message describing the service issue. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__( message=message, category=ErrorCategory.SERVICE, @@ -261,6 +334,12 @@ class InternalError(AixplainBaseException): """Raised when there is an internal system error.""" def __init__(self, message: str, **kwargs): + """Initialize internal error. + + Args: + message: Error message describing the internal issue. + **kwargs: Additional keyword arguments passed to parent class. + """ # Server errors (5xx) should generally be retryable status_code = kwargs.get("status_code") retry_recommended = kwargs.pop("retry_recommended", False) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 190555ce..090b4c0e 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Agent module for aiXplain SDK. + +This module provides the Agent class and related functionality for creating and managing +AI agents that can execute tasks using various tools and models. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +22,8 @@ Description: Agentification Class """ + +__author__ = "aiXplain" import json import logging import re @@ -33,7 +37,7 @@ from aixplain.modules.model import Model from aixplain.modules.agent.agent_task import WorkflowTask, AgentTask from aixplain.modules.agent.output_format import OutputFormat -from aixplain.modules.agent.tool import Tool +from aixplain.modules.agent.tool import Tool, DeployableTool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables, validate_history @@ -48,7 +52,7 @@ import warnings -class Agent(Model, DeployableMixin[Tool]): +class Agent(Model, DeployableMixin[Union[Tool, DeployableTool]]): """An advanced AI system that performs tasks using specialized tools from the aiXplain marketplace. This class represents an AI agent that can understand natural language instructions, @@ -124,6 +128,8 @@ def __init__( Defaults to AssetStatus.DRAFT. tasks (List[AgentTask], optional): List of tasks the Agent can perform. Defaults to empty list. + workflow_tasks (List[WorkflowTask], optional): List of workflow tasks + the Agent can execute. Defaults to empty list. output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. **additional_info: Additional configuration parameters. @@ -144,7 +150,8 @@ def __init__( self.status = status if tasks: warnings.warn( - "The 'tasks' parameter is deprecated and will be removed in a future version. " "Use 'workflow_tasks' instead.", + "The 'tasks' parameter is deprecated and will be removed in a future version. " + "Use 'workflow_tasks' instead.", DeprecationWarning, stacklevel=2, ) @@ -171,9 +178,9 @@ def _validate(self) -> None: from aixplain.utils.llm_utils import get_llm_instance # validate name - assert ( - re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None - ), "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." + assert re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None, ( + "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." + ) llm = get_llm_instance(self.llm_id, api_key=self.api_key, use_cache=True) @@ -233,6 +240,14 @@ def validate(self, raise_exception: bool = False) -> bool: return self.is_valid def generate_session_id(self, history: list = None) -> str: + """Generate a unique session ID for agent conversations. + + Args: + history (list, optional): Previous conversation history. Defaults to None. + + Returns: + str: A unique session identifier based on timestamp and random components. + """ if history: validate_history(history) timestamp = datetime.now().strftime("%Y%m%d%H%M%S") @@ -307,6 +322,7 @@ def run( max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. + Returns: Dict: parsed output from model """ @@ -406,10 +422,10 @@ def run_async( expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. output_format (ResponseFormat, optional): response format. Defaults to TEXT. evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the agent configuration. Can be a dictionary, EvolveParam instance, or None. + Returns: dict: polling URL in response """ - if session_id is not None and history is not None: raise ValueError("Provide either `session_id` or `history`, not both.") @@ -434,7 +450,9 @@ def run_async( assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + assert "query" in data and data["query"] is not None, ( + "When providing a dictionary, 'query' must be provided." + ) query = data.get("query") if session_id is None: session_id = data.get("session_id") @@ -447,7 +465,9 @@ def run_async( # process content inputs if content is not None: - assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." + assert FileFactory.check_storage_type(query) == StorageType.TEXT, ( + "When providing 'content', query must be text." + ) if isinstance(content, list): assert len(content) <= 3, "The maximum number of content inputs is 3." @@ -511,6 +531,11 @@ def run_async( ) def to_dict(self) -> Dict: + """Convert the Agent instance to a dictionary representation. + + Returns: + Dict: Dictionary containing the agent's configuration and metadata. + """ from aixplain.factories.agent_factory.utils import build_tool_payload return { @@ -674,9 +699,9 @@ def delete(self) -> None: "referencing it." ) else: - message = f"Agent Deletion Error (HTTP {r.status_code}): " f"{error_message}." + message = f"Agent Deletion Error (HTTP {r.status_code}): {error_message}." except ValueError: - message = f"Agent Deletion Error (HTTP {r.status_code}): " "There was an error in deleting the agent." + message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent." logging.error(message) raise Exception(message) @@ -701,7 +726,7 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. Please use save() instead.", DeprecationWarning, stacklevel=2, ) diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index b83c9704..76bbd802 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Agent tool module for aiXplain SDK. + +This module provides tool classes and functionality for agents to interact with +various services, models, and data sources. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +22,8 @@ Description: Agentification Class """ + +__author__ = "aiXplain" from abc import ABC from typing import Optional, Text from aixplain.utils import config @@ -83,13 +87,10 @@ def validate(self): """ raise NotImplementedError - def deploy(self) -> None: - """Deploys the tool to make it available for use. - This method should handle any necessary setup or deployment steps - required to make the tool operational. +class DeployableTool(Tool): + """Tool that can be deployed.""" - Raises: - NotImplementedError: This is an abstract method that must be implemented by subclasses. - """ + def deploy(self) -> None: + """Deploy the tool.""" raise NotImplementedError diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 0a027c04..1fa32541 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Custom Python code tool for aiXplain SDK agents. + +This module provides a tool that allows agents to execute custom Python code +in a controlled environment. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +23,8 @@ Agentification Class """ +__author__ = "aiXplain" + from typing import Text, Union, Callable, Optional from aixplain.modules.agent.tool import Tool import logging @@ -110,9 +114,9 @@ def validate(self): if name and name.strip() != "": self.name = name - assert ( - self.description and self.description.strip() != "" - ), "Custom Python Code Tool Error: Tool description is required" + assert self.description and self.description.strip() != "", ( + "Custom Python Code Tool Error: Tool description is required" + ) assert self.code and self.code.strip() != "", "Custom Python Code Tool Error: Code is required" assert self.name and self.name.strip() != "", "Custom Python Code Tool Error: Name is required" assert self.status in [ @@ -127,11 +131,3 @@ def __repr__(self) -> Text: Text: A string in the format "CustomPythonCodeTool(name=)". """ return f"CustomPythonCodeTool(name={self.name})" - - def deploy(self): - """Deploy the custom Python code tool. - - This is a placeholder method as custom Python code tools are automatically - deployed when created. - """ - pass diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index ab3ae1af..b6cd060a 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Model tool for aiXplain SDK agents. + +This module provides a tool that allows agents to interact with AI models +and execute model-based tasks. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +22,9 @@ Description: Agentification Class """ + +__author__ = "aiXplain" + from typing import Optional, Union, Text, Dict, List from aixplain.enums import AssetStatus, Function, Supplier @@ -162,8 +167,8 @@ def to_dict(self) -> Dict: } def validate(self) -> None: - """ - Validates the tool. + """Validates the tool. + Notes: - Checks if the tool has a function or model. - If the function is a string, it converts it to a Function enum. @@ -174,16 +179,16 @@ def validate(self) -> None: """ from aixplain.enums import FunctionInputOutput - assert ( - self.function is not None or self.model is not None - ), "Agent Creation Error: Either function or model must be provided when instantiating a tool." + assert self.function is not None or self.model is not None, ( + "Agent Creation Error: Either function or model must be provided when instantiating a tool." + ) if self.function is not None: if isinstance(self.function, str): self.function = Function(self.function) - assert ( - self.function is None or self.function is not Function.UTILITIES or self.model is not None - ), "Agent Creation Error: Utility function must be used with an associated model." + assert self.function is None or self.function is not Function.UTILITIES or self.model is not None, ( + "Agent Creation Error: Utility function must be used with an associated model." + ) try: if isinstance(self.supplier, dict): @@ -196,7 +201,9 @@ def validate(self) -> None: try: self.model = self._get_model() except Exception: - raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") + raise Exception( + f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it." + ) self.function = self.model.function if isinstance(self.model.supplier, Supplier): self.supplier = self.model.supplier @@ -213,7 +220,6 @@ def validate(self) -> None: self.parameters = self.validate_parameters(self.parameters) self.name = self.name if self.name else set_tool_name(self.function, self.supplier, self.model) - def get_parameters(self) -> Dict: """Get the tool's parameters, either from explicit settings or the model object. @@ -224,9 +230,9 @@ def get_parameters(self) -> Dict: # If parameters were not explicitly provided, get them from the model if ( self.parameters is None - and self.model_object is not None # noqa: W503 - and hasattr(self.model_object, "model_params") # noqa: W503 - and self.model_object.model_params is not None # noqa: W503 + and self.model_object is not None + and hasattr(self.model_object, "model_params") + and self.model_object.model_params is not None ): return self.model_object.model_params.to_list() return self.parameters @@ -265,11 +271,11 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) # Get default parameters if none provided if ( self.model_object is not None - and hasattr(self.model_object, "model_params") # noqa: W503 - and self.model_object.model_params is not None # noqa: W503 + and hasattr(self.model_object, "model_params") + and self.model_object.model_params is not None ): return self.model_object.model_params.to_list() - + elif self.function is not None: function_params = self.function.get_parameters() if function_params is not None: @@ -280,8 +286,8 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) expected_params = None if ( self.model_object is not None - and hasattr(self.model_object, "model_params") # noqa: W503 - and self.model_object.model_params is not None # noqa: W503 + and hasattr(self.model_object, "model_params") + and self.model_object.model_params is not None ): expected_params = self.model_object.model_params elif self.function is not None: @@ -300,7 +306,9 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) invalid_params = received_param_names - expected_param_names if invalid_params: - raise ValueError(f"Invalid parameters provided: {invalid_params}. Expected parameters are: {expected_param_names}") + raise ValueError( + f"Invalid parameters provided: {invalid_params}. Expected parameters are: {expected_param_names}" + ) return received_parameters @@ -314,11 +322,3 @@ def __repr__(self) -> Text: supplier_str = self.supplier.value if self.supplier is not None else None model_str = self.model.id if self.model is not None else None return f"ModelTool(name={self.name}, function={self.function}, supplier={supplier_str}, model={model_str})" - - def deploy(self): - """Deploy the model tool. - - This is a placeholder method as model tools are managed through the aiXplain platform - and don't require explicit deployment. - """ - pass diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 4ed2bd8e..ab847275 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Pipeline tool for aiXplain SDK agents. + +This module provides a tool that allows agents to execute AI pipelines +and chain multiple AI operations together. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +22,9 @@ Description: Agentification Class """ + +__author__ = "aiXplain" + from typing import Text, Union, Optional from aixplain.modules.agent.tool import Tool @@ -123,11 +128,3 @@ def validate(self): if self.name.strip() == "": self.name = pipeline_obj.name self.status = pipeline_obj.status - - def deploy(self): - """Deploy the pipeline tool. - - This is a placeholder method as pipeline tools are managed through the aiXplain platform - and don't require explicit deployment. - """ - pass diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index 7be02bed..181bdf72 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Python interpreter tool for aiXplain SDK agents. + +This module provides a tool that allows agents to execute Python code +using an interpreter in a controlled environment. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +23,8 @@ Agentification Class """ +__author__ = "aiXplain" + from aixplain.modules.agent.tool import Tool from aixplain.enums import AssetStatus @@ -83,11 +87,3 @@ def __repr__(self) -> Text: Text: A string in the format "PythonInterpreterTool()". """ return "PythonInterpreterTool()" - - def deploy(self): - """Deploy the Python interpreter tool. - - This is a placeholder method as the Python interpreter tool is automatically - available and doesn't require explicit deployment. - """ - pass diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 331f4539..e124cf76 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""SQL tool for aiXplain SDK agents. + +This module provides a tool that allows agents to execute SQL queries +against databases and CSV files. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +22,9 @@ Description: Agentification Class """ + +__author__ = "aiXplain" + import os import warnings import validators @@ -28,23 +33,23 @@ from typing import Text, Optional, Dict, List, Union import sqlite3 from aixplain.enums import AssetStatus -from aixplain.modules.agent.tool import Tool +from aixplain.modules.agent.tool import DeployableTool class SQLToolError(Exception): - """Base exception for SQL Tool errors""" + """Base exception for SQL Tool errors.""" pass class CSVError(SQLToolError): - """Exception for CSV-related errors""" + """Exception for CSV-related errors.""" pass class DatabaseError(SQLToolError): - """Exception for database-related errors""" + """Exception for database-related errors.""" pass @@ -221,7 +226,9 @@ def create_database_from_csv(csv_path: str, database_path: str, table_name: str # Clean column names and track changes original_columns = df.columns.tolist() cleaned_columns = [clean_column_name(col) for col in original_columns] - changed_columns = [(orig, cleaned) for orig, cleaned in zip(original_columns, cleaned_columns) if orig != cleaned] + changed_columns = [ + (orig, cleaned) for orig, cleaned in zip(original_columns, cleaned_columns) if orig != cleaned + ] if changed_columns: changes = ", ".join([f"'{orig}' to '{cleaned}'" for orig, cleaned in changed_columns]) @@ -333,7 +340,7 @@ def get_table_names_from_schema(schema: str) -> List[str]: return table_names -class SQLTool(Tool): +class SQLTool(DeployableTool): """A tool for executing SQL commands in an SQLite database. This tool provides an interface for interacting with SQLite databases, including @@ -381,7 +388,6 @@ def __init__( Raises: SQLToolError: If required parameters are missing or invalid. """ - super().__init__(name=name, description=description, **additional_info) self.database = database @@ -440,9 +446,9 @@ def validate(self): # Handle database validation if not ( str(self.database).startswith("s3://") - or str(self.database).startswith("http://") # noqa: W503 - or str(self.database).startswith("https://") # noqa: W503 - or validators.url(self.database) # noqa: W503 + or str(self.database).startswith("http://") + or str(self.database).startswith("https://") + or validators.url(self.database) ): if not os.path.exists(self.database): raise SQLToolError(f"Database '{self.database}' does not exist") diff --git a/aixplain/modules/mixins.py b/aixplain/modules/mixins.py index 8b3b3723..20aa2983 100644 --- a/aixplain/modules/mixins.py +++ b/aixplain/modules/mixins.py @@ -1,4 +1,5 @@ -""" +"""Mixins for common functionality across different asset types. + Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -48,6 +49,7 @@ def _validate_deployment_readiness(self) -> None: items (Optional[List[T]], optional): List of items to validate (e.g. tools for Agent, agents for TeamAgent) Raises: + AlreadyDeployedError: If the asset is already deployed ValueError: If the asset is not ready to be deployed """ asset_type = self.__class__.__name__ @@ -64,6 +66,7 @@ def deploy(self) -> None: Classes that need special deployment handling should override this method. Raises: + AlreadyDeployedError: If the asset is already deployed ValueError: If the asset is not ready to be deployed """ self._validate_deployment_readiness() @@ -71,24 +74,27 @@ def deploy(self) -> None: try: # Deploy tools if present if hasattr(self, "tools"): - [tool.deploy() for tool in self.tools] for tool in self.tools: - try: - tool.deploy() - except AlreadyDeployedError: - pass - except Exception as e: - raise Exception(f"Error deploying tool {tool.name}: {e}") from e + if hasattr(tool, "deploy"): + try: + tool.deploy() + except AlreadyDeployedError: + # Skip tools that are already deployed + pass + except Exception as e: + raise Exception(f"Error deploying tool {tool.name}: {e}") from e # Deploy agents if present (for TeamAgent) if hasattr(self, "agents"): for agent in self.agents: - try: - agent.deploy() - except AlreadyDeployedError: - pass - except Exception as e: - raise Exception(f"Error deploying agent {agent.name}: {e}") from e + if hasattr(agent, "deploy"): + try: + agent.deploy() + except AlreadyDeployedError: + # Skip agents that are already deployed + pass + except Exception as e: + raise Exception(f"Error deploying agent {agent.name}: {e}") from e self.status = AssetStatus.ONBOARDED self.update() diff --git a/aixplain/modules/model/connection.py b/aixplain/modules/model/connection.py index 94369f1f..148ff5f7 100644 --- a/aixplain/modules/model/connection.py +++ b/aixplain/modules/model/connection.py @@ -1,3 +1,23 @@ +"""Copyright 2025 The aiXplain SDK authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Ahmet Gündüz +Date: September 10th 2025 +Description: + Connection Tool Class. +""" + from aixplain.enums import Function, Supplier, FunctionType, ResponseStatus from aixplain.modules.model import Model from aixplain.utils import config @@ -52,6 +72,15 @@ def __repr__(self): class ConnectionTool(Model): + """A class representing a connection tool. + + This class defines the structure of a connection tool with its actions and action scope. + + Attributes: + actions (List[ConnectAction]): A list of available actions for this connection. + action_scope (Optional[List[ConnectAction]]): The scope of actions for this connection. + """ + actions: List[ConnectAction] action_scope: Optional[List[ConnectAction]] = None @@ -84,9 +113,9 @@ def __init__( function_type (FunctionType, optional): Type of the Connection. Defaults to FunctionType.CONNECTION. **additional_info: Any additional Connection info to be saved """ - assert ( - function_type == FunctionType.CONNECTION or function_type == FunctionType.MCP_CONNECTION - ), "Connection only supports connection function" + assert function_type == FunctionType.CONNECTION or function_type == FunctionType.MCP_CONNECTION, ( + "Connection only supports connection function" + ) super().__init__( id=id, name=name, @@ -188,9 +217,9 @@ def get_parameters(self) -> List[Dict]: Raises: AssertionError: If the action scope is not set or is empty. """ - assert ( - self.action_scope is not None and len(self.action_scope) > 0 - ), f"Please set the scope of actions for the connection '{self.id}'." + assert self.action_scope is not None and len(self.action_scope) > 0, ( + f"Please set the scope of actions for the connection '{self.id}'." + ) response = [ { "code": action.code, diff --git a/tests/functional/agent/agent_mcp_deploy_test.py b/tests/functional/agent/agent_mcp_deploy_test.py new file mode 100644 index 00000000..74c3bcf3 --- /dev/null +++ b/tests/functional/agent/agent_mcp_deploy_test.py @@ -0,0 +1,180 @@ +"""Tests for agent deployment functionality with MCP (Model Control Protocol) tools. + +This test verifies that agents can be created, deployed, and used with MCP tools, +including proper status management and cleanup. +""" + +import pytest +import logging +from uuid import uuid4 +from aixplain.factories import ToolFactory, AgentFactory +from aixplain.modules.model.integration import AuthenticationSchema +from aixplain.enums import AssetStatus +from aixplain.exceptions import AlreadyDeployedError +from tests.test_deletion_utils import safe_delete_all_agents_and_team_agents + + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="function") +def cleanup_agents(): + """Fixture to clean up agents before and after tests.""" + # Clean up before test + safe_delete_all_agents_and_team_agents() + + yield True + + # Clean up after test + safe_delete_all_agents_and_team_agents() + + +@pytest.fixture +def mcp_tool(): + """Create an MCP tool for testing.""" + tool = ToolFactory.create( + integration="686eb9cd26480723d0634d3e", # Remote MCP ID + name=f"Test Remote MCP {uuid4()}", + authentication_schema=AuthenticationSchema.API_KEY, + data={"url": "https://remote.mcpservers.org/fetch/mcp"}, + ) + + # Set allowed actions (using ... as in original script) + tool.allowed_actions = [...] + + # Filter actions to only include "fetch" action + tool.action_scope = [action for action in tool.actions if action.code == "fetch"] + + return tool + + +@pytest.fixture +def test_agent(cleanup_agents, mcp_tool): + """Create a test agent with MCP tool.""" + agent = AgentFactory.create( + name=f"Test Agent {uuid4()}", + description="This agent is used to scrape websites", + instructions="You are a helpful assistant that can scrape any given website", + tools=[mcp_tool], + llm="669a63646eb56306647e1091", + ) + return agent + + +def test_agent_creation_with_mcp_tool(test_agent, mcp_tool): + """Test that an agent can be created with an MCP tool.""" + assert test_agent is not None + assert test_agent.name.startswith("Test Agent") + assert test_agent.description == "This agent is used to scrape websites" + assert len(test_agent.tools) == 1 + assert test_agent.tools[0] == mcp_tool + assert test_agent.status == AssetStatus.DRAFT + + +def test_agent_run_before_deployment(test_agent): + """Test that an agent can run before being deployed.""" + response = test_agent.run("Give me information about the aixplain website") + + assert response is not None + assert hasattr(response, "data") + assert hasattr(response.data, "output") + + +def test_agent_deployment(test_agent): + """Test that an agent can be deployed successfully.""" + # Verify initial status is DRAFT + assert test_agent.status == AssetStatus.DRAFT + + # Deploy the agent + test_agent.deploy() + + # Verify status is now ONBOARDED + assert test_agent.status == AssetStatus.ONBOARDED + + +def test_agent_retrieval_after_deployment(test_agent): + """Test that a deployed agent can be retrieved and maintains its status.""" + # Deploy the agent first + test_agent.deploy() + agent_id = test_agent.id + + # Retrieve the agent by ID + retrieved_agent = AgentFactory.get(agent_id) + + assert retrieved_agent is not None + assert retrieved_agent.id == agent_id + assert retrieved_agent.status == AssetStatus.ONBOARDED + + +def test_deployed_agent_cannot_be_deployed_again(test_agent): + """Test that attempting to deploy an already deployed agent raises an error.""" + # Deploy the agent first + test_agent.deploy() + assert test_agent.status == AssetStatus.ONBOARDED + + # Attempt to deploy again should raise AlreadyDeployedError + with pytest.raises(AlreadyDeployedError, match="is already deployed"): + test_agent.deploy() + + +def test_deployed_agent_can_run(test_agent): + """Test that a deployed agent can still run queries.""" + # Deploy the agent first + test_agent.deploy() + + # Run a query on the deployed agent + response = test_agent.run("Give me information about the aixplain website") + + assert response is not None + assert hasattr(response, "data") + assert hasattr(response.data, "output") + + +def test_agent_lifecycle_end_to_end(cleanup_agents, mcp_tool): + """Test the complete agent lifecycle: create, run, deploy, retrieve, run, delete.""" + # Create agent + agent = AgentFactory.create( + name=f"Test Agent Lifecycle {uuid4()}", + description="This agent is used for lifecycle testing", + instructions="You are a helpful assistant that can scrape any given website", + tools=[mcp_tool], + llm="669a63646eb56306647e1091", + ) + + # Test initial state + assert agent.status == AssetStatus.DRAFT + + # Test run before deployment + response = agent.run("Give me information about the aixplain website") + assert response is not None + + # Deploy agent + agent.deploy() + assert agent.status == AssetStatus.ONBOARDED + + # Retrieve agent by ID + agent_id = agent.id + retrieved_agent = AgentFactory.get(agent_id) + assert retrieved_agent.status == AssetStatus.ONBOARDED + + # Test run after deployment + response = retrieved_agent.run("Give me information about the aixplain website") + assert response is not None + + # Clean up + retrieved_agent.delete() + + +def test_mcp_tool_properties(mcp_tool): + """Test that the MCP tool has the expected properties.""" + assert mcp_tool is not None + assert mcp_tool.name.startswith("MCP Server") + assert hasattr(mcp_tool, "actions") + assert hasattr(mcp_tool, "action_scope") + assert len(mcp_tool.action_scope) >= 0 # Should have filtered actions + + # Verify that action_scope only contains "fetch" actions + for action in mcp_tool.action_scope: + assert action.code == "fetch"