Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 10 additions & 25 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,4 @@
repos:
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: coverage run --source=. -m pytest tests/unit
language: python
pass_filenames: false
types: [python]
always_run: true

- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: black
language_version: python3
args: # arguments to configure black
- --line-length=128

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0 # Use the latest version
hooks:
Expand All @@ -25,16 +7,19 @@ repos:
- id: check-merge-conflict
- id: check-added-large-files

- repo: https://github.com/pycqa/flake8
rev: 7.2.0
hooks:
- id: flake8
args: # arguments to configure flake8
- --ignore=E402,E501,E203,W503

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: coverage run --source=. -m pytest tests/unit
language: python
pass_filenames: false
types: [python]
always_run: true
12 changes: 6 additions & 6 deletions aixplain/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Error message registry for aiXplain SDK.
"""Error message registry for aiXplain SDK.

This module maintains a centralized registry of error messages used throughout the aiXplain ecosystem.
It allows developers to look up existing error messages and reuse them instead of creating new ones.
Expand All @@ -9,6 +8,7 @@
AixplainBaseException,
AuthenticationError,
ValidationError,
AlreadyDeployedError,
ResourceError,
BillingError,
SupplierError,
Expand All @@ -33,12 +33,11 @@


def get_error_from_status_code(status_code: int, error_details: str = None) -> AixplainBaseException:
"""
Map HTTP status codes to appropriate exception types.
"""Map HTTP status codes to appropriate exception types.

Args:
status_code (int): The HTTP status code to map.
default_message (str, optional): The default message to use if no specific message is available.
error_details (str, optional): Additional error details to include in the message.

Returns:
AixplainBaseException: An exception of the appropriate type.
Expand Down Expand Up @@ -126,5 +125,6 @@ def get_error_from_status_code(status_code: int, error_details: str = None) -> A
# Catch-all for other client/server errors
category = "Client" if 400 <= status_code < 500 else "Server"
return InternalError(
message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(), status_code=status_code
message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(),
status_code=status_code,
)
79 changes: 79 additions & 0 deletions aixplain/exceptions/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Exception types and error handling for the aiXplain SDK."""

from enum import Enum
from typing import Optional, Dict, Any

Expand Down Expand Up @@ -121,6 +123,17 @@ def __init__(
retry_recommended: bool = False,
error_code: Optional[ErrorCode] = None,
):
"""Initialize the base exception with structured error information.

Args:
message: Error message describing the issue.
category: Category of the error (default: UNKNOWN).
severity: Severity level of the error (default: ERROR).
status_code: HTTP status code if applicable.
details: Additional error context and details.
retry_recommended: Whether retrying the operation might succeed.
error_code: Standardized error code for the exception.
"""
self.message = message
self.category = category
self.severity = severity
Expand Down Expand Up @@ -163,6 +176,12 @@ class AuthenticationError(AixplainBaseException):
"""Raised when authentication fails."""

def __init__(self, message: str, **kwargs):
"""Initialize authentication error.

Args:
message: Error message describing the authentication issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.AUTHENTICATION,
Expand All @@ -177,6 +196,12 @@ class ValidationError(AixplainBaseException):
"""Raised when input validation fails."""

def __init__(self, message: str, **kwargs):
"""Initialize validation error.

Args:
message: Error message describing the validation issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.VALIDATION,
Expand All @@ -187,10 +212,34 @@ def __init__(self, message: str, **kwargs):
)


class AlreadyDeployedError(AixplainBaseException):
"""Raised when attempting to deploy an asset that is already deployed."""

def __init__(self, message: str, **kwargs):
"""Initialize already deployed error.

Args:
message: Error message describing the deployment state conflict.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
retry_recommended=kwargs.pop("retry_recommended", False),
error_code=ErrorCode.AX_VAL_ERROR,
**kwargs,
)


class ResourceError(AixplainBaseException):
"""Raised when a resource is unavailable."""

def __init__(self, message: str, **kwargs):
"""Initialize resource error.

Args:
message: Error message describing the resource issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.RESOURCE,
Expand All @@ -205,6 +254,12 @@ class BillingError(AixplainBaseException):
"""Raised when there are billing issues."""

def __init__(self, message: str, **kwargs):
"""Initialize billing error.

Args:
message: Error message describing the billing issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.BILLING,
Expand All @@ -219,6 +274,12 @@ class SupplierError(AixplainBaseException):
"""Raised when there are issues with external suppliers."""

def __init__(self, message: str, **kwargs):
"""Initialize supplier error.

Args:
message: Error message describing the supplier issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.SUPPLIER,
Expand All @@ -233,6 +294,12 @@ class NetworkError(AixplainBaseException):
"""Raised when there are network connectivity issues."""

def __init__(self, message: str, **kwargs):
"""Initialize network error.

Args:
message: Error message describing the network issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.NETWORK,
Expand All @@ -247,6 +314,12 @@ class ServiceError(AixplainBaseException):
"""Raised when a service is unavailable."""

def __init__(self, message: str, **kwargs):
"""Initialize service error.

Args:
message: Error message describing the service issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
super().__init__(
message=message,
category=ErrorCategory.SERVICE,
Expand All @@ -261,6 +334,12 @@ class InternalError(AixplainBaseException):
"""Raised when there is an internal system error."""

def __init__(self, message: str, **kwargs):
"""Initialize internal error.

Args:
message: Error message describing the internal issue.
**kwargs: Additional keyword arguments passed to parent class.
"""
# Server errors (5xx) should generally be retryable
status_code = kwargs.get("status_code")
retry_recommended = kwargs.pop("retry_recommended", False)
Expand Down
53 changes: 39 additions & 14 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__author__ = "aiXplain"
"""Agent module for aiXplain SDK.

This module provides the Agent class and related functionality for creating and managing
AI agents that can execute tasks using various tools and models.

"""
Copyright 2024 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -20,6 +22,8 @@
Description:
Agentification Class
"""

__author__ = "aiXplain"
import json
import logging
import re
Expand All @@ -33,7 +37,7 @@
from aixplain.modules.model import Model
from aixplain.modules.agent.agent_task import WorkflowTask, AgentTask
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.agent.tool import Tool
from aixplain.modules.agent.tool import Tool, DeployableTool
from aixplain.modules.agent.agent_response import AgentResponse
from aixplain.modules.agent.agent_response_data import AgentResponseData
from aixplain.modules.agent.utils import process_variables, validate_history
Expand All @@ -48,7 +52,7 @@
import warnings


class Agent(Model, DeployableMixin[Tool]):
class Agent(Model, DeployableMixin[Union[Tool, DeployableTool]]):
"""An advanced AI system that performs tasks using specialized tools from the aiXplain marketplace.

This class represents an AI agent that can understand natural language instructions,
Expand Down Expand Up @@ -124,6 +128,8 @@ def __init__(
Defaults to AssetStatus.DRAFT.
tasks (List[AgentTask], optional): List of tasks the Agent can perform.
Defaults to empty list.
workflow_tasks (List[WorkflowTask], optional): List of workflow tasks
the Agent can execute. Defaults to empty list.
output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
**additional_info: Additional configuration parameters.
Expand All @@ -144,7 +150,8 @@ def __init__(
self.status = status
if tasks:
warnings.warn(
"The 'tasks' parameter is deprecated and will be removed in a future version. " "Use 'workflow_tasks' instead.",
"The 'tasks' parameter is deprecated and will be removed in a future version. "
"Use 'workflow_tasks' instead.",
DeprecationWarning,
stacklevel=2,
)
Expand All @@ -171,9 +178,9 @@ def _validate(self) -> None:
from aixplain.utils.llm_utils import get_llm_instance

# validate name
assert (
re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None
), "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed."
assert re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None, (
"Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed."
)

llm = get_llm_instance(self.llm_id, api_key=self.api_key, use_cache=True)

Expand Down Expand Up @@ -233,6 +240,14 @@ def validate(self, raise_exception: bool = False) -> bool:
return self.is_valid

def generate_session_id(self, history: list = None) -> str:
"""Generate a unique session ID for agent conversations.

Args:
history (list, optional): Previous conversation history. Defaults to None.

Returns:
str: A unique session identifier based on timestamp and random components.
"""
if history:
validate_history(history)
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
Expand Down Expand Up @@ -307,6 +322,7 @@ def run(
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.
output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.

Returns:
Dict: parsed output from model
"""
Expand Down Expand Up @@ -406,10 +422,10 @@ def run_async(
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
output_format (ResponseFormat, optional): response format. Defaults to TEXT.
evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the agent configuration. Can be a dictionary, EvolveParam instance, or None.

Returns:
dict: polling URL in response
"""

if session_id is not None and history is not None:
raise ValueError("Provide either `session_id` or `history`, not both.")

Expand All @@ -434,7 +450,9 @@ def run_async(
assert data is not None or query is not None, "Either 'data' or 'query' must be provided."
if data is not None:
if isinstance(data, dict):
assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided."
assert "query" in data and data["query"] is not None, (
"When providing a dictionary, 'query' must be provided."
)
query = data.get("query")
if session_id is None:
session_id = data.get("session_id")
Expand All @@ -447,7 +465,9 @@ def run_async(

# process content inputs
if content is not None:
assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text."
assert FileFactory.check_storage_type(query) == StorageType.TEXT, (
"When providing 'content', query must be text."
)

if isinstance(content, list):
assert len(content) <= 3, "The maximum number of content inputs is 3."
Expand Down Expand Up @@ -511,6 +531,11 @@ def run_async(
)

def to_dict(self) -> Dict:
"""Convert the Agent instance to a dictionary representation.

Returns:
Dict: Dictionary containing the agent's configuration and metadata.
"""
from aixplain.factories.agent_factory.utils import build_tool_payload

return {
Expand Down Expand Up @@ -674,9 +699,9 @@ def delete(self) -> None:
"referencing it."
)
else:
message = f"Agent Deletion Error (HTTP {r.status_code}): " f"{error_message}."
message = f"Agent Deletion Error (HTTP {r.status_code}): {error_message}."
except ValueError:
message = f"Agent Deletion Error (HTTP {r.status_code}): " "There was an error in deleting the agent."
message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent."
logging.error(message)
raise Exception(message)

Expand All @@ -701,7 +726,7 @@ def update(self) -> None:
stack = inspect.stack()
if len(stack) > 2 and stack[1].function != "save":
warnings.warn(
"update() is deprecated and will be removed in a future version. " "Please use save() instead.",
"update() is deprecated and will be removed in a future version. Please use save() instead.",
DeprecationWarning,
stacklevel=2,
)
Expand Down
Loading
Loading