diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 555f4920..4f0364e1 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -16,3 +16,5 @@ from .sort_order import SortOrder from .response_status import ResponseStatus from .database_source import DatabaseSourceType +from .embedding_model import EmbeddingModel +from .asset_status import AssetStatus diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py new file mode 100644 index 00000000..8769e3dd --- /dev/null +++ b/aixplain/enums/embedding_model.py @@ -0,0 +1,30 @@ +__author__ = "aiXplain" + +""" +Copyright 2023 The aiXplain SDK authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+Author: aiXplain team +Date: February 17th 2025 +Description: + Embedding Model Enum +""" + +from enum import Enum + + +class EmbeddingModel(Enum): + SNOWFLAKE_ARCTIC_EMBED_M_LONG = "6658d40729985c2cf72f42ec" + OPENAI_ADA002 = "6734c55df127847059324d9e" + SNOWFLAKE_ARCTIC_EMBED_L_V2_0 = "678a4f8547f687504744960a" + JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" + + def __str__(self): + return self._value_ diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index af59f17f..9532bf72 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -43,6 +43,7 @@ from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin +from aixplain.enums import DatabaseSourceType class AgentFactory: @@ -197,7 +198,7 @@ def create_sql_tool( cls, description: Text, source: str, - source_type: str, + source_type: Union[str, DatabaseSourceType], schema: Optional[Text] = None, tables: Optional[List[Text]] = None, enable_commit: bool = False, @@ -207,7 +208,7 @@ def create_sql_tool( Args: description (Text): description of the database tool source (Union[Text, Dict]): database source - can be a connection string or dictionary with connection details - source_type (Text): type of source (postgresql, sqlite, csv) + source_type (Union[str, DatabaseSourceType]): type of source (postgresql, sqlite, csv) or DatabaseSourceType enum schema (Optional[Text], optional): database schema description tables (Optional[List[Text]], optional): table names to work with (optional) enable_commit (bool, optional): enable to modify the database (optional) @@ -237,7 +238,6 @@ def create_sql_tool( get_table_schema, get_table_names_from_schema, ) - from aixplain.enums import DatabaseSourceType if not source: raise SQLToolError("Source must be provided") @@ -245,10 +245,16 @@ def create_sql_tool( raise SQLToolError("Source type must be provided") # Validate source type - try: 
- source_type = DatabaseSourceType.from_string(source_type) - except ValueError as e: - raise SQLToolError(str(e)) + if isinstance(source_type, str): + try: + source_type = DatabaseSourceType.from_string(source_type) + except ValueError as e: + raise SQLToolError(str(e)) + elif isinstance(source_type, DatabaseSourceType): + # Already the correct type, no conversion needed + pass + else: + raise SQLToolError(f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}") database_path = None # Final database path to pass to SQLTool @@ -258,14 +264,17 @@ def create_sql_tool( raise SQLToolError(f"CSV file '{source}' does not exist") if not source.endswith(".csv"): raise SQLToolError(f"File '{source}' is not a CSV file") + if tables and len(tables) > 1: + raise SQLToolError("CSV source type only supports one table") # Create database name from CSV filename or use custom table name base_name = os.path.splitext(os.path.basename(source))[0] db_path = os.path.join(os.path.dirname(source), f"{base_name}.db") + table_name = tables[0] if tables else None try: # Create database from CSV - schema = create_database_from_csv(source, db_path) + schema = create_database_from_csv(source, db_path, table_name) database_path = db_path # Get table names if not provided diff --git a/aixplain/factories/index_factory.py b/aixplain/factories/index_factory.py index 85be5cc6..7588e583 100644 --- a/aixplain/factories/index_factory.py +++ b/aixplain/factories/index_factory.py @@ -1,16 +1,20 @@ from aixplain.modules.model.index_model import IndexModel from aixplain.factories import ModelFactory -from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier +from aixplain.enums import EmbeddingModel, Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier from typing import Optional, Text, Union, List, Tuple +AIR_MODEL_ID = "66eae6656eb56311f2595011" + class IndexFactory(ModelFactory): @classmethod - def create(cls, name: 
Text, description: Text) -> IndexModel: + def create( + cls, name: Text, description: Text, embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002 + ) -> IndexModel: """Create a new index collection""" - model = cls.get("66eae6656eb56311f2595011") + model = cls.get(AIR_MODEL_ID) - data = {"data": name, "description": description} + data = {"data": name, "description": description, "model": embedding_model.value} response = model.run(data=data) if response.status == ResponseStatus.SUCCESS: model_id = response.data @@ -19,7 +23,7 @@ def create(cls, name: Text, description: Text) -> IndexModel: error_message = f"Index Factory Exception: {response.error_message}" if error_message == "": - error_message = "Index Factory Exception:An error occurred while creating the index collection." + error_message = "Index Factory Exception: An error occurred while creating the index collection." raise Exception(error_message) @classmethod diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index e17841e6..c2611d6c 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -23,15 +23,15 @@ import json import logging +from typing import Dict, List, Optional, Text, Union +from urllib.parse import urljoin from aixplain.enums.supplier import Supplier from aixplain.modules.agent import Agent -from aixplain.modules.team_agent import TeamAgent +from aixplain.modules.team_agent import TeamAgent, InspectorTarget from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent from aixplain.utils.file_utils import _request_with_retry -from typing import Dict, List, Optional, Text, Union -from urllib.parse import urljoin class TeamAgentFactory: @@ -47,9 +47,29 @@ def create( version: Optional[Text] = None, use_mentalist: bool = True, use_inspector: bool = True, + num_inspectors: int = 1, + inspector_targets: 
List[Union[InspectorTarget, Text]] = [InspectorTarget.STEPS], use_mentalist_and_inspector: bool = False, # TODO: remove this ) -> TeamAgent: - """Create a new team agent in the platform.""" + """Create a new team agent in the platform. + + Args: + name: The name of the team agent. + agents: A list of agents to be added to the team. + llm_id: The ID of the LLM to be used for the team agent. + description: The description of the team agent. + api_key: The API key to be used for the team agent. + supplier: The supplier of the team agent. + version: The version of the team agent. + use_mentalist: Whether to use the mentalist agent. + use_inspector: Whether to use the inspector agent. + num_inspectors: The number of inspectors to be used for each inspection. + inspector_targets: Which stages to be inspected during an execution of the team agent. (steps, output) + use_mentalist_and_inspector: Whether to use the mentalist and inspector agents. (legacy) + + Returns: + A new team agent instance. + """ assert len(agents) > 0, "TeamAgent Onboarding Error: At least one agent must be provided." agent_list = [] for agent in agents: @@ -68,8 +88,22 @@ def create( assert isinstance(agent, Agent), "TeamAgent Onboarding Error: Agents must be instances of Agent class" agent_list.append(agent_obj) - if use_inspector and not use_mentalist: - raise Exception("TeamAgent Onboarding Error: To use the Inspector agent, you must enable Mentalist.") + # NOTE: backend expects max_inspectors (for "generated" inspectors) + max_inspectors = num_inspectors + + if use_inspector: + try: + # convert to enum if string and check its validity + inspector_targets = [InspectorTarget(target) for target in inspector_targets] + except ValueError: + raise ValueError("TeamAgent Onboarding Error: Invalid inspector target. 
Valid targets are: steps, output") + + if not use_mentalist: + raise Exception("TeamAgent Onboarding Error: To use the Inspector agent, you must enable Mentalist.") + if max_inspectors < 1: + raise Exception( + "TeamAgent Onboarding Error: The number of inspectors must be greater than 0 when using the Inspector agent." + ) if use_mentalist_and_inspector: mentalist_llm_id = llm_id @@ -100,6 +134,8 @@ def create( "supervisorId": llm_id, "plannerId": mentalist_llm_id, "inspectorId": inspector_llm_id, + "maxInspectors": max_inspectors, + "inspectorTargets": inspector_targets if use_inspector else [], "supplier": supplier, "version": version, "status": "draft", diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 5e865cd0..debadff8 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -1,12 +1,14 @@ __author__ = "lucaspavanelli" import logging +from typing import Dict, Text, List +from urllib.parse import urljoin + import aixplain.utils.config as config from aixplain.enums.asset_status import AssetStatus from aixplain.modules.agent import Agent -from aixplain.modules.team_agent import TeamAgent -from typing import Dict, Text, List -from urllib.parse import urljoin +from aixplain.modules.team_agent import TeamAgent, InspectorTarget + GPT_4o_ID = "6646261c6eb563165658bbb1" @@ -29,6 +31,8 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = ) continue + inspector_targets = [InspectorTarget(target.lower()) for target in payload.get("inspectorTargets", [])] + team_agent = TeamAgent( id=payload.get("id", ""), name=payload.get("name", ""), @@ -40,6 +44,8 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = llm_id=payload.get("llmId", GPT_4o_ID), use_mentalist=True if payload.get("plannerId", None) is not None else False, use_inspector=True if payload.get("inspectorId", None) is not None else 
False, + max_inspectors=payload.get("maxInspectors", 1), + inspector_targets=inspector_targets, api_key=api_key, status=AssetStatus(payload["status"]), ) diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 6e651eeb..511e4a8f 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -36,6 +36,7 @@ def __init__(self, code: Union[Text, Callable], description: Text = "", **additi def to_dict(self): return { + "name": self.name, "description": self.description, "type": "utility", "utility": "custom_python_code", diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index f0cb88e7..ba16317a 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -42,6 +42,7 @@ def __init__( function: Optional[Union[Function, Text]] = None, supplier: Optional[Union[Dict, Supplier]] = None, model: Optional[Union[Text, Model]] = None, + name: Optional[Text] = None, description: Text = "", parameters: Optional[Dict] = None, **additional_info, @@ -59,7 +60,8 @@ def __init__( function is not None or model is not None ), "Agent Creation Error: Either function or model must be provided when instantiating a tool." 
- super().__init__(name="", description=description, **additional_info) + name = name or "" + super().__init__(name=name, description=description, **additional_info) if function is not None: if isinstance(function, str): function = Function(function) @@ -104,6 +106,7 @@ def to_dict(self) -> Dict: return { "function": self.function.value if self.function is not None else None, "type": "model", + "name": self.name, "description": self.description, "supplier": supplier, "version": self.version if self.version else None, diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index ab3b4311..13d3c46f 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -20,7 +20,7 @@ Description: Agentification Class """ -from typing import Text, Union +from typing import Text, Union, Optional from aixplain.modules.agent.tool import Tool from aixplain.modules.pipeline import Pipeline @@ -38,6 +38,7 @@ def __init__( self, description: Text, pipeline: Union[Text, Pipeline], + name: Optional[Text] = None, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. 
@@ -46,7 +47,9 @@ def __init__( description (Text): description of the tool pipeline (Union[Text, Pipeline]): pipeline """ - super().__init__("", description, **additional_info) + name = name or "" + super().__init__(name=name, description=description, **additional_info) + if isinstance(pipeline, Pipeline): pipeline = pipeline.id self.pipeline = pipeline @@ -54,6 +57,7 @@ def __init__( def to_dict(self): return { "assetId": self.pipeline, + "name": self.name, "description": self.description, "type": "pipeline", } diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 6935e6df..81444db0 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -127,7 +127,7 @@ def get_table_schema(database_path: str) -> str: raise DatabaseError(f"Unexpected error while getting table schema: {str(e)}") -def create_database_from_csv(csv_path: str, database_path: str) -> str: +def create_database_from_csv(csv_path: str, database_path: str, table_name: str = None) -> str: """Create SQLite database from CSV file and return the schema""" if not os.path.exists(csv_path): raise CSVError(f"CSV file '{csv_path}' does not exist") @@ -161,7 +161,9 @@ def create_database_from_csv(csv_path: str, database_path: str) -> str: cursor = conn.cursor() # Check if table already exists - table_name = clean_column_name(os.path.splitext(os.path.basename(csv_path))[0]) + + if table_name is None: + table_name = clean_column_name(os.path.splitext(os.path.basename(csv_path))[0]) cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'") if cursor.fetchone(): warnings.warn(f"Table '{table_name}' already exists in the database and will be replaced") @@ -262,6 +264,7 @@ def __init__( schema: Optional[Text] = None, tables: Optional[Union[List[Text], Text]] = None, enable_commit: bool = False, + name: Optional[Text] = None, **additional_info, ) -> None: """Tool to execute SQL query commands in 
an SQLite database. @@ -273,7 +276,10 @@ def __init__( tables (Optional[Union[List[Text], Text]]): table names to work with (optional) enable_commit (bool): enable to modify the database (optional) """ - super().__init__("", description, **additional_info) + + name = name or "" + super().__init__(name=name, description=description, **additional_info) + self.database = database self.schema = schema self.tables = tables if isinstance(tables, list) else [tables] if tables else None @@ -281,6 +287,7 @@ def __init__( def to_dict(self) -> Dict[str, Text]: return { + "name": self.name, "description": self.description, "parameters": [ {"name": "database", "value": self.database}, diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 5788daab..adedfcfb 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -243,7 +243,7 @@ def run( start = time.time() payload = build_payload(data=data, parameters=parameters) url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") - logging.debug(f"Model Run Sync: Start service for {name} - {url}") + logging.debug(f"Model Run Sync: Start service for {name} - {url} - {payload}") response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) if response["status"] == "IN_PROGRESS": try: @@ -281,8 +281,8 @@ def run_async( dict: polling URL in response """ url = f"{self.url}/{self.id}" - logging.debug(f"Model Run Async: Start service for {name} - {url}") payload = build_payload(data=data, parameters=parameters) + logging.debug(f"Model Run Async: Start service for {name} - {url} - {payload}") response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( status=response.pop("status", ResponseStatus.FAILED), diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index fae597b8..b86b1ee5 100644 --- a/aixplain/modules/model/index_model.py +++ 
b/aixplain/modules/model/index_model.py @@ -1,12 +1,42 @@ -from aixplain.enums import Function, Supplier, ResponseStatus +from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus, StorageType from aixplain.modules.model import Model from aixplain.utils import config from aixplain.modules.model.response import ModelResponse from typing import Text, Optional, Union, Dict from aixplain.modules.model.record import Record +from enum import Enum from typing import List +class IndexFilterOperator(Enum): + EQUALS = "==" + NOT_EQUALS = "!=" + CONTAINS = "in" + NOT_CONTAINS = "not in" + GREATER_THAN = ">" + LESS_THAN = "<" + GREATER_THAN_OR_EQUALS = ">=" + LESS_THAN_OR_EQUALS = "<=" + + +class IndexFilter: + field: str + value: str + operator: Union[IndexFilterOperator, str] + + def __init__(self, field: str, value: str, operator: Union[IndexFilterOperator, str]): + self.field = field + self.value = value + self.operator = operator + + def to_dict(self): + return { + "field": self.field, + "value": self.value, + "operator": self.operator.value if isinstance(self.operator, IndexFilterOperator) else self.operator, + } + + class IndexModel(Model): def __init__( self, @@ -19,6 +49,7 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, + embedding_model: Optional[EmbeddingModel] = None, **additional_info, ) -> None: """Index Init @@ -33,6 +64,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. + embedding_model (EmbeddingModel, optional): embedding model. Defaults to None. 
**additional_info: Any additional Model info to be saved """ assert function == Function.SEARCH, "Index only supports search function" @@ -50,14 +82,61 @@ def __init__( ) self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL + self.embedding_model = embedding_model + + def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: + """Search for documents in the index + + Args: + query (str): Query to be searched + top_k (int, optional): Number of results to be returned. Defaults to 10. + filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. - def search(self, query: str, top_k: int = 10, filters: Dict = {}) -> ModelResponse: - data = {"action": "search", "data": query, "payload": {"filters": filters, "top_k": top_k}} + Returns: + ModelResponse: Response from the indexing service + + Example: + - index_model.search("Hello") + - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) + """ + from aixplain.factories import FileFactory + + uri, value_type = "", "text" + storage_type = FileFactory.check_storage_type(query) + if storage_type in [StorageType.FILE, StorageType.URL]: + uri = FileFactory.to_link(query) + query = "" + value_type = "image" + + data = { + "action": "search", + "data": query or uri, + "dataType": value_type, + "filters": [filter.to_dict() for filter in filters], + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + } return self.run(data=data) def upsert(self, documents: List[Record]) -> ModelResponse: + """Upsert documents into the index + + Args: + documents (List[Record]): List of documents to be upserted + + Returns: + ModelResponse: Response from the indexing service + + Example: + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + """ + # Validate documents + for doc in documents: + doc.validate() + # Convert documents to payloads 
payloads = [doc.to_dict() for doc in documents] - data = {"action": "ingest", "data": "", "payload": {"payloads": payloads}} + # Build payload + data = {"action": "ingest", "data": payloads} + # Run the indexing service response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: response.data = payloads diff --git a/aixplain/modules/model/record.py b/aixplain/modules/model/record.py index a3c57173..4cc958d8 100644 --- a/aixplain/modules/model/record.py +++ b/aixplain/modules/model/record.py @@ -1,9 +1,17 @@ +from aixplain.enums import DataType, StorageType from typing import Optional from uuid import uuid4 class Record: - def __init__(self, value: str, value_type: str = "text", id: Optional[str] = None, uri: str = "", attributes: dict = {}): + def __init__( + self, + value: str = "", + value_type: DataType = DataType.TEXT, + id: Optional[str] = None, + uri: str = "", + attributes: dict = {}, + ): self.value = value self.value_type = value_type self.id = id if id is not None else str(uuid4()) @@ -12,9 +20,30 @@ def __init__(self, value: str, value_type: str = "text", id: Optional[str] = Non def to_dict(self): return { - "value": self.value, - "value_type": self.value_type, - "id": self.id, + "data": self.value, + "dataType": str(self.value_type), + "document_id": self.id, "uri": self.uri, "attributes": self.attributes, } + + def validate(self): + """Validate the record""" + from aixplain.factories import FileFactory + from aixplain.modules.model.utils import is_supported_image_type + + assert self.value_type in [DataType.TEXT, DataType.IMAGE], "Index Upsert Error: Invalid value type" + if self.value_type == DataType.IMAGE: + assert self.uri is not None and self.uri != "", "Index Upsert Error: URI is required for image records" + else: + assert self.value is not None and self.value != "", "Index Upsert Error: Value is required for text records" + + storage_type = FileFactory.check_storage_type(self.uri) + + # Check if value is an image file or URL + 
if storage_type in [StorageType.FILE, StorageType.URL]: + if is_supported_image_type(self.uri): + self.value_type = DataType.IMAGE + self.uri = FileFactory.to_link(self.uri) if storage_type == StorageType.FILE else self.uri + else: + raise Exception(f"Index Upsert Error: Unsupported file type ({self.uri})") diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index a2fbd05b..6fa7d9a5 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -19,6 +19,7 @@ Utility Model Class """ import logging +import warnings from aixplain.enums import Function, Supplier, DataType from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model import Model @@ -43,17 +44,20 @@ def validate(self): def to_dict(self): return {"name": self.name, "description": self.description, "type": self.type.value} + # Tool decorator -def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] = None, output_examples: Text = "", status = AssetStatus.DRAFT): +def utility_tool( + name: Text, description: Text, inputs: List[UtilityModelInput] = None, output_examples: Text = "", status=AssetStatus.DRAFT +): """Decorator for utility tool functions - + Args: name: Name of the utility tool description: Description of what the utility tool does inputs: List of input parameters, must be UtilityModelInput objects output_examples: Examples of expected outputs status: Asset status - + Raises: ValueError: If name or description is empty TypeError: If inputs contains non-UtilityModelInput objects @@ -63,7 +67,7 @@ def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] raise ValueError("Utility tool name cannot be empty") if not description or not description.strip(): raise ValueError("Utility tool description cannot be empty") - + # Validate inputs if inputs is not None: if not isinstance(inputs, list): @@ -71,7 +75,7 @@ def utility_tool(name: Text, description: Text, 
inputs: List[UtilityModelInput] for input_param in inputs: if not isinstance(input_param, UtilityModelInput): raise TypeError(f"Invalid input parameter: {input_param}. All inputs must be UtilityModelInput objects") - + def decorator(func): func._is_utility_tool = True # Mark function as utility tool func._tool_name = name.strip() @@ -80,12 +84,16 @@ def decorator(func): func._tool_output_examples = output_examples func._tool_status = status return func + return decorator class UtilityModel(Model): """Ready-to-use Utility Model. + Note: Non-deployed utility models (status=DRAFT) will expire after 24 hours after creation. + Use the .deploy() method to make the model permanent. + Attributes: id (Text): ID of the Model name (Text): Name of the Model @@ -116,7 +124,7 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - status: AssetStatus = AssetStatus.ONBOARDED,# TODO: change to draft when we have the backend ready + status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: """Utility Model Init @@ -161,6 +169,13 @@ def __init__( status = AssetStatus.DRAFT self.status = status + if status == AssetStatus.DRAFT: + warnings.warn( + "WARNING: Non-deployed utility models (status=DRAFT) will expire after 24 hours after creation. 
" + "Use .deploy() method to make the model permanent.", + UserWarning, + ) + def validate(self): """Validate the Utility Model.""" description = None @@ -187,7 +202,6 @@ def validate(self): assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required" - def _model_exists(self): if self.id is None or self.id == "": return False @@ -199,7 +213,6 @@ def _model_exists(self): raise Exception() return True - def to_dict(self): return { "name": self.name, @@ -225,7 +238,6 @@ def update(self): stacklevel=2, ) - self.validate() url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 7b494305..7ba42f2d 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -123,8 +123,8 @@ def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text, Text]: "Utility Model Error:If the function is not decorated with @utility_tool, the description must be provided in the docstring" ) # get parameters of the function - f = re.findall(r"main\((.*?(?:\s*=\s*[^,)]+)?(?:\s*,\s*.*?(?:\s*=\s*[^,)]+)?)*)\)", str_code) - parameters = f[0].split(",") if len(f) > 0 else [] + params_match = re.search(r"def\s+\w+\s*\((.*?)\)\s*(?:->.*?)?:", str_code) + parameters = params_match.group(1).split(",") if params_match else [] for input in parameters: assert ( @@ -176,6 +176,12 @@ def parse_code_decorated(code: Union[Text, Callable]) -> Tuple[Text, List, Text] inputs, description, name = [], "", "" str_code = "" + # Add explicit type checking for class instances + if inspect.isclass(code) or (not isinstance(code, (str, Callable)) and hasattr(code, "__class__")): + raise TypeError( + f"Code must be either a string or a callable function, not a class or class instance. 
You tried to pass a class or class instance: {code}" + ) + if isinstance(code, Callable) and hasattr(code, "_is_utility_tool"): str_code = inspect.getsource(code) # Use the information directly from the decorated callable @@ -212,7 +218,7 @@ def parse_code_decorated(code: Union[Text, Callable]) -> Tuple[Text, List, Text] description = code.__doc__.strip() if code.__doc__ else "" name = code.__name__ # Try to infer parameters - params_match = re.search(r"def\s+\w+\s*\((.*?)\):", str_code) + params_match = re.search(r"def\s+\w+\s*\((.*?)\)\s*(?:->.*?)?:", str_code) parameters = params_match.group(1).split(",") if params_match else [] for input in parameters: @@ -338,3 +344,7 @@ def parse_code_decorated(code: Union[Text, Callable]) -> Tuple[Text, List, Text] os.remove(local_path) return code, inputs, description, name + + +def is_supported_image_type(value: str) -> bool: + return any(value.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"]) diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 16f998d6..0642e6ee 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -115,23 +115,26 @@ def __polling( try: response_body = self.poll(poll_url, name=name) logging.debug(f"Polling for Pipeline: Status of polling for {name} : {response_body}") - end = time.time() if not response_body["completed"]: time.sleep(wait_time) if wait_time < 60: wait_time *= 1.1 except Exception: - logging.error(f"Polling for Pipeline: polling for {name} : Continue") + logging.error(f"Polling for Pipeline '{self.id}': polling for {name} ({poll_url}): Continue") break if response_body["status"] == ResponseStatus.SUCCESS: try: - logging.debug(f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}") + logging.debug( + f"Polling for Pipeline '{self.id}' - Final status of polling for {name} ({poll_url}): SUCCESS - {response_body}" + ) except Exception: - 
logging.error(f"Polling for Pipeline: Final status of polling for {name} : ERROR - {response_body}") + logging.error( + f"Polling for Pipeline '{self.id}' - Final status of polling for {name} ({poll_url}): ERROR - {response_body}" + ) else: logging.error( - f"Polling for Pipeline: Final status of polling for {name} : No response in {timeout} seconds - {response_body}" + f"Polling for Pipeline '{self.id}' - Final status of polling for {name} ({poll_url}): No response in {timeout} seconds - {response_body}" ) return response_body @@ -160,8 +163,7 @@ def poll( resp["data"] = json.loads(resp["data"])["response"] except Exception: resp = r.json() - logging.info(f"Single Poll for Pipeline: Status of polling for {name} : {resp}") - + logging.info(f"Single Poll for Pipeline '{self.id}' - Status of polling for {name} ({poll_url}): {resp}") if response_version == "v1": return resp status = ResponseStatus(resp.pop("status", "failed")) @@ -397,7 +399,6 @@ def run_async( if 200 <= r.status_code < 300: resp = r.json() logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - if response_version == "v1": return resp res = PipelineResponse( @@ -425,7 +426,7 @@ def run_async( f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." 
) - logging.error(f"Error in request for {name} - {r.status_code}: {error}") + logging.error(f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}") if response_version == "v1": return { "status": "failed", diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 3e00ab61..87c78159 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -26,8 +26,11 @@ import time import traceback import re +from enum import Enum +from typing import Dict, List, Text, Optional, Union +from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.enums import ResponseStatus from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier from aixplain.enums.asset_status import AssetStatus @@ -35,13 +38,18 @@ from aixplain.modules.model import Model from aixplain.modules.agent import Agent, OutputFormat from aixplain.modules.agent.agent_response import AgentResponse -from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables -from typing import Dict, List, Text, Optional, Union -from urllib.parse import urljoin +from aixplain.utils import config +from aixplain.utils.file_utils import _request_with_retry -from aixplain.utils import config +class InspectorTarget(str, Enum): + # TODO: INPUT + STEPS = "steps" + OUTPUT = "output" + + def __str__(self): + return self._value_ class TeamAgent(Model): @@ -76,6 +84,8 @@ def __init__( cost: Optional[Dict] = None, use_mentalist: bool = True, use_inspector: bool = True, + max_inspectors: int = 1, + inspector_targets: List[InspectorTarget] = [InspectorTarget.STEPS], status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: @@ -100,6 +110,8 @@ def __init__( self.llm_id = llm_id self.use_mentalist = use_mentalist self.use_inspector = use_inspector + self.max_inspectors = max_inspectors + self.inspector_targets 
= inspector_targets if isinstance(status, str): try: @@ -207,13 +219,9 @@ def run_async( from aixplain.factories.file_factory import FileFactory if not self.is_valid: - raise Exception( - "Team Agent is not valid. Please validate the team agent before running." - ) + raise Exception("Team Agent is not valid. Please validate the team agent before running.") - assert ( - data is not None or query is not None - ), "Either 'data' or 'query' must be provided." + assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): assert ( @@ -260,16 +268,8 @@ def run_async( "sessionId": session_id, "history": history, "executionParams": { - "maxTokens": ( - parameters["max_tokens"] - if "max_tokens" in parameters - else max_tokens - ), - "maxIterations": ( - parameters["max_iterations"] - if "max_iterations" in parameters - else max_iterations - ), + "maxTokens": (parameters["max_tokens"] if "max_tokens" in parameters else max_tokens), + "maxIterations": (parameters["max_iterations"] if "max_iterations" in parameters else max_iterations), "outputFormat": output_format.value, }, } @@ -327,6 +327,8 @@ def to_dict(self) -> Dict: "supervisorId": self.llm_id, "plannerId": self.llm_id if self.use_mentalist else None, "inspectorId": self.llm_id if self.use_inspector else None, + "maxInspectors": self.max_inspectors, + "inspectorTargets": [target.value for target in self.inspector_targets], "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, "version": self.version, "status": self.status.value, @@ -362,9 +364,7 @@ def validate(self, raise_exception: bool = False) -> bool: raise e else: logging.warning(f"Team Agent Validation Error: {e}") - logging.warning( - "You won't be able to run the Team Agent until the issues are handled manually." 
- ) + logging.warning("You won't be able to run the Team Agent until the issues are handled manually.") return self.is_valid @@ -377,7 +377,8 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " + "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -389,9 +390,7 @@ def update(self) -> None: payload = self.to_dict() - logging.debug( - f"Start service for PUT Update Team Agent - {url} - {headers} - {json.dumps(payload)}" - ) + logging.debug(f"Start service for PUT Update Team Agent - {url} - {headers} - {json.dumps(payload)}") resp = "No specified error." try: r = _request_with_retry("put", url, headers=headers, json=payload) diff --git a/aixplain/processes/data_onboarding/process_interval_files.py b/aixplain/processes/data_onboarding/process_interval_files.py deleted file mode 100644 index 576cd434..00000000 --- a/aixplain/processes/data_onboarding/process_interval_files.py +++ /dev/null @@ -1,190 +0,0 @@ -__author__ = "thiagocastroferreira" - -import json -import logging -import os -import pandas as pd -import tarfile - -from aixplain.enums.data_type import DataType -from aixplain.enums.file_type import FileType -from aixplain.enums.storage_type import StorageType -from aixplain.modules.content_interval import ( - ContentInterval, - AudioContentInterval, - ImageContentInterval, - TextContentInterval, - VideoContentInterval, -) -from aixplain.modules.file import File -from aixplain.modules.metadata import MetaData -from aixplain.utils.file_utils import upload_data -from pathlib import Path -from tqdm import tqdm -from typing import Any, Dict, List, Text, Tuple - - -def compress_folder(folder_path: str): - with tarfile.open(folder_path + ".tgz", "w:gz") as tar: - for name in os.listdir(folder_path): - 
tar.add(os.path.join(folder_path, name)) - return folder_path + ".tgz" - - -def process_interval(interval: Any, storage_type: StorageType, interval_folder: Text) -> List[Dict]: - """Process text files - - Args: - intervals (Any): content intervals to process the content - storage_type (StorageType): type of storage: URL, local path or textual content - - Returns: - List[Dict]: content interval - """ - if storage_type == StorageType.FILE: - # Check the size of file and assert a limit of 50 MB - assert ( - os.path.getsize(interval.content) <= 25000000 - ), f'Data Asset Onboarding Error: Local text file "{interval}" exceeds the size limit of 25 MB.' - fname = os.path.basename(interval) - new_path = os.path.join(audio_folder, fname) - if os.path.exists(new_path) is False: - shutil.copy2(audio_path, new_path) - return [interval.__dict__ for interval in intervals] - - -def validate_format(index: int, interval: Dict, metadata: MetaData) -> ContentInterval: - """Validate the interval format - - Args: - index (int): row index - interval (Dict): interval to be validated - metadata (MetaData): metadata - - Returns: - ContentInterval: _description_ - """ - if metadata.dtype == DataType.AUDIO_INTERVAL: - try: - if isinstance(interval, list): - interval = [AudioContentInterval(**interval_) for interval_ in interval] - else: - interval = [AudioContentInterval(**interval)] - except Exception as e: - message = f'Data Asset Onboarding Error: Audio Interval in row {index} of Column "{metadata.name}" is not following the format. Check the "AudioContentInterval" class for the correct format.' 
- logging.exception(message) - raise Exception(message) - elif metadata.dtype == DataType.IMAGE_INTERVAL: - try: - if isinstance(interval, list): - interval = [ImageContentInterval(**interval_) for interval_ in interval] - else: - interval = [ImageContentInterval(**interval)] - except Exception as e: - message = f'Data Asset Onboarding Error: Image Interval in row {index} of Column "{metadata.name}" is not following the format. Check the "ImageContentInterval" class for the correct format.' - logging.exception(message) - raise Exception(message) - elif metadata.dtype == DataType.TEXT_INTERVAL: - try: - if isinstance(interval, list): - interval = [TextContentInterval(**interval_) for interval_ in interval] - else: - interval = [TextContentInterval(**interval)] - except Exception as e: - message = f'Data Asset Onboarding Error: Text Interval in row {index} of Column "{metadata.name}" is not following the format. Check the "TextContentInterval" class for the correct format.' - logging.exception(message) - raise Exception(message) - elif metadata.dtype == DataType.VIDEO_INTERVAL: - try: - if isinstance(interval, list): - interval = [VideoContentInterval(**interval_) for interval_ in interval] - else: - interval = [VideoContentInterval(**interval)] - except Exception as e: - message = f'Data Asset Onboarding Error: Video Interval in row {index} of Column "{metadata.name}" is not following the format. Check the "VideoContentInterval" class for the correct format.' - logging.exception(message) - raise Exception(message) - return interval - - -def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) -> Tuple[List[File], int, int]: - """Process a list of local interval files, compress and upload them to pre-signed URLs in S3 - - Explanation: - Each interval on "paths" is processed. If the interval content is in a public link or local file, it will be downloaded and added to an index CSV file. 
- The intervals are processed in batches such that at each "batch_size" texts, the index CSV file is uploaded into a pre-signed URL in s3 and reset. - - Args: - metadata (MetaData): meta data of the asset - paths (List): list of paths to local files - folder (Path): local folder to save compressed files before upload them to s3. - - Returns: - Tuple[List[File], int, int]: list of s3 links, data colum index and number of rows - """ - logging.debug(f'Data Asset Onboarding: Processing "{metadata.name}".') - interval_folder = Path(".") - if metadata.storage_type in [StorageType.FILE, StorageType.TEXT]: - interval_folder = Path(os.path.join(folder, "data")) - interval_folder.mkdir(exist_ok=True) - - idx = 0 - data_column_idx = -1 - files, batch = [], [] - for i in tqdm(range(len(paths)), desc=f' Data "{metadata.name}" onboarding progress', position=1, leave=False): - path = paths[i] - try: - dataframe = pd.read_csv(path) - except Exception as e: - message = f'Data Asset Onboarding Error: Local file "{path}" not found.' - logging.exception(message) - raise Exception(message) - - # process intervals - for j in tqdm(range(len(dataframe)), desc=" File onboarding progress", position=2, leave=False): - row = dataframe.iloc[j] - try: - interval = row[metadata.name] - except Exception as e: - message = f'Data Asset Onboarding Error: Column "{metadata.name}" not found in the local file {path}.' 
- logging.exception(message) - raise Exception(message) - - # interval = validate_format(index=j, interval=interval, metadata=metadata) - - try: - interval = process_interval(interval, metadata.storage_type) - batch.append(interval) - except Exception as e: - logging.exception(e) - raise Exception(e) - - idx += 1 - if ((idx) % batch_size) == 0: - batch_index = str(len(files) + 1).zfill(8) - file_name = f"{folder}/{metadata.name}-{batch_index}.csv.gz" - - df = pd.DataFrame({metadata.name: batch}) - start, end = idx - len(batch), idx - df["@INDEX"] = range(start, end) - df.to_csv(file_name, compression="gzip", index=False) - s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip") - files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) - # get data column index - data_column_idx = df.columns.to_list().index(metadata.name) - batch = [] - - if len(batch) > 0: - batch_index = str(len(files) + 1).zfill(8) - file_name = f"{folder}/{metadata.name}-{batch_index}.csv.gz" - - df = pd.DataFrame({metadata.name: batch}) - start, end = idx - len(batch), idx - df["@INDEX"] = range(start, end) - df.to_csv(file_name, compression="gzip", index=False) - s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip") - files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) - # get data column index - data_column_idx = df.columns.to_list().index(metadata.name) - batch = [] - return files, data_column_idx, idx diff --git a/aixplain/processes/data_onboarding/process_media_files.py b/aixplain/processes/data_onboarding/process_media_files.py index 62fd369a..1e007d85 100644 --- a/aixplain/processes/data_onboarding/process_media_files.py +++ b/aixplain/processes/data_onboarding/process_media_files.py @@ -17,7 +17,6 @@ from pathlib import Path from tqdm import tqdm from typing import List, Tuple -from urllib.parse import urlparse AUDIO_MAX_SIZE = 50000000 IMAGE_TEXT_MAX_SIZE = 25000000 @@ -76,7 +75,7 @@ 
def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> row = dataframe.iloc[j] try: media_path = row[metadata.name] - except Exception as e: + except Exception: message = f'Data Asset Onboarding Error: Column "{metadata.name}" not found in the local file "{path}".' logging.exception(message) raise Exception(message) @@ -129,13 +128,13 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> if metadata.start_column is not None or metadata.end_column is not None: assert ( metadata.dsubtype != DataSubtype.INTERVAL - ), f"Data Asset Onboarding Error: Interval data types can not be cropped. Remove start and end columns." + ), "Data Asset Onboarding Error: Interval data types can not be cropped. Remove start and end columns." # adding ranges to crop the media if it is the case if metadata.start_column is not None: try: start_intervals.append(row[metadata.start_column]) - except Exception as e: + except Exception: message = f'Data Asset Onboarding Error: Column "{metadata.start_column}" not found.' logging.exception(message) raise Exception(message) @@ -143,7 +142,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> if metadata.end_column is not None: try: end_intervals.append(row[metadata.end_column]) - except Exception as e: + except Exception: message = f'Data Asset Onboarding Error: Column "{metadata.end_column}" not found.' 
logging.exception(message) raise Exception(message) @@ -167,7 +166,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> # compress the folder compressed_folder = compress_folder(data_file_name) # upload zipped medias into s3 - s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar") + s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar", return_s3_link=True) # update index files pointing the s3 link df["@SOURCE"] = s3_compressed_folder # remove media folder @@ -200,7 +199,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> end_column_idx = df.columns.to_list().index(end_column) df.to_csv(index_file_name, compression="gzip", index=False) - s3_link = upload_data(index_file_name, content_type="text/csv", content_encoding="gzip") + s3_link = upload_data(index_file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) @@ -225,7 +224,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> # compress the folder compressed_folder = compress_folder(data_file_name) # upload zipped medias into s3 - s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar") + s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar", return_s3_link=True) # update index files pointing the s3 link df["@SOURCE"] = s3_compressed_folder # remove media folder @@ -258,7 +257,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> end_column_idx = df.columns.to_list().index(end_column) df.to_csv(index_file_name, compression="gzip", index=False) - s3_link = upload_data(index_file_name, content_type="text/csv", content_encoding="gzip") + s3_link = upload_data(index_file_name, 
content_type="text/csv", content_encoding="gzip", return_s3_link=True) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) diff --git a/aixplain/processes/data_onboarding/process_text_files.py b/aixplain/processes/data_onboarding/process_text_files.py index 1ba7f47e..84df057f 100644 --- a/aixplain/processes/data_onboarding/process_text_files.py +++ b/aixplain/processes/data_onboarding/process_text_files.py @@ -12,7 +12,7 @@ from aixplain.utils.file_utils import upload_data from pathlib import Path from tqdm import tqdm -from typing import List, Optional, Text, Tuple +from typing import List, Text, Tuple def process_text(content: str, storage_type: StorageType) -> Text: @@ -69,7 +69,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) - path = paths[i] try: dataframe = pd.read_csv(path) - except Exception as e: + except Exception: message = f'Data Asset Onboarding Error: Local file "{path}" not found.' logging.exception(message) raise Exception(message) @@ -79,7 +79,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) - row = dataframe.iloc[j] try: text_path = row[metadata.name] - except Exception as e: + except Exception: message = f'Data Asset Onboarding Error: Column "{metadata.name}" not found in the local file {path}.' 
logging.exception(message) raise Exception(message) @@ -100,7 +100,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) - start, end = idx - len(batch), idx df["@INDEX"] = range(start, end) df.to_csv(file_name, compression="gzip", index=False) - s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip") + s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) @@ -114,7 +114,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) - start, end = idx - len(batch), idx df["@INDEX"] = range(start, end) df.to_csv(file_name, compression="gzip", index=False) - s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip") + s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index b1e16cf0..58781dfb 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -92,6 +92,7 @@ def upload_data( content_type: Text = "text/csv", content_encoding: Optional[Text] = None, nattempts: int = 2, + return_s3_link: bool = True, ): """Upload files to S3 with pre-signed URLs @@ -103,6 +104,7 @@ def upload_data( content_type (Text, optional): Type of content. Defaults to "text/csv". content_encoding (Text, optional): Content encoding. Defaults to None. nattempts (int, optional): Number of attempts for diminish the risk of exceptions. Defaults to 2. + return_s3_link (bool, optional): If True, the function will return the s3 link instead of the presigned url. Defaults to True. 
Reference: https://python.plainenglish.io/upload-files-to-aws-s3-using-pre-signed-urls-in-python-d3c2fcab1b41 @@ -143,9 +145,11 @@ def upload_data( return upload_data(file_name, content_type, content_encoding, nattempts - 1) else: raise Exception("File Uploading Error: Failure on Uploading to S3.") - bucket_name = re.findall(r"https://(.*?).s3.amazonaws.com", presigned_url)[0] - s3_link = f"s3://{bucket_name}/{path}" - return s3_link + if return_s3_link: + bucket_name = re.findall(r"https://(.*?).s3.amazonaws.com", presigned_url)[0] + s3_link = f"s3://{bucket_name}/{path}" + return s3_link + return presigned_url except Exception: if nattempts > 0: return upload_data(file_name, content_type, content_encoding, nattempts - 1) diff --git a/setup.cfg b/setup.cfg index 5792718a..337e05e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,3 @@ [metadata] -description-file=README.md +description_file=README.md license_files=LICENSE.rst diff --git a/tests/conftest.py b/tests/conftest.py index a03eea30..a17177b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ from dotenv import load_dotenv # Load environment variables once for all tests -load_dotenv() +load_dotenv(override=True) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index cf018b8d..1a2cbcb4 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -386,6 +386,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert "eve" in str(response["data"]["output"]).lower() os.remove("ftest.db") + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) @@ -417,7 +418,7 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert tool.tables == ["employees"] assert ( tool.schema - == 'CREATE TABLE test (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 + == 'CREATE TABLE employees (\n 
"id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 ) assert not tool.enable_commit # must be False by default @@ -456,3 +457,4 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): # Cleanup os.remove("test.csv") os.remove("test.db") + agent.delete() diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index d3d0082f..a5d97ae3 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -1,7 +1,10 @@ __author__ = "thiagocastroferreira" +import pytest +import os +import requests -from aixplain.enums import Function +from aixplain.enums import Function, EmbeddingModel from aixplain.factories import ModelFactory from aixplain.modules import LLM from datetime import datetime, timedelta, timezone @@ -55,7 +58,15 @@ def test_run_async(): assert "teste" in response["data"].lower() -def test_index_model(): +@pytest.mark.parametrize( + "embedding_model", + [ + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), + pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), + ], +) +def test_index_model(embedding_model): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory @@ -63,7 +74,7 @@ def test_index_model(): for index in IndexFactory.list()["results"]: index.delete() - index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4())) + index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) response = index_model.search("Hello") assert str(response.status) == "SUCCESS" @@ -74,6 +85,41 @@ def test_index_model(): assert str(response.status) == "SUCCESS" assert 
"aixplain" in response.data.lower() assert index_model.count() == 1 + index_model.upsert([Record(value="The world is great", value_type="text", uri="", id="2", attributes={})]) + assert index_model.count() == 2 + index_model.delete() + + +@pytest.mark.parametrize( + "embedding_model", + [ + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), + pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), + pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, id="Jina Clip v2 Multimodal"), + ], +) +def test_index_model_with_filter(embedding_model): + from uuid import uuid4 + from aixplain.modules.model.record import Record + from aixplain.factories import IndexFactory + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + for index in IndexFactory.list()["results"]: + index.delete() + + index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})]) + index_model.upsert( + [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] + ) + assert index_model.count() == 2 + response = index_model.search( + "", filters=[IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS)] + ) + assert str(response.status) == "SUCCESS" + assert "world" in response.data.lower() + assert len(response.details) == 1 index_model.delete() @@ -94,3 +140,58 @@ def test_llm_run_with_file(): # Verify response assert response["status"] == "SUCCESS" assert "🤖" in response["data"], "Robot emoji should be present in the response" + + +def test_index_model_with_image(): + from aixplain.factories import IndexFactory + from aixplain.modules.model.record import Record + from uuid import 
uuid4 + + for index in IndexFactory.list()["results"]: + index.delete() + + index_model = IndexFactory.create( + name=f"Image Index {uuid4()}", description="Index for images", embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL + ) + + records = [] + # Building image + records.append( + Record( + uri="https://aixplain-platform-assets.s3.us-east-1.amazonaws.com/samples/building.png", + value_type="image", + attributes={}, + ) + ) + + # beach image + image_url = "https://aixplain-platform-assets.s3.us-east-1.amazonaws.com/samples/hurricane.jpeg" + response = requests.get(image_url) + if response.status_code == 200: + with open("hurricane.jpeg", "wb") as f: + f.write(response.content) + records.append(Record(uri="hurricane.jpeg", value_type="image", attributes={})) + + # people image + image_url = "https://aixplain-platform-assets.s3.us-east-1.amazonaws.com/samples/faces.jpeg" + records.append(Record(uri=image_url, value_type="image", attributes={})) + + records.append(Record(value="Hello, world!", value_type="text", uri="", attributes={})) + + index_model.upsert(records) + + response = index_model.search("beach") + assert str(response.status) == "SUCCESS" + second_record = response.details[1]["metadata"]["uri"] + assert "hurricane" in second_record.lower() + + response = index_model.search("people") + assert str(response.status) == "SUCCESS" + first_record = response.details[0]["data"] + assert "hello" in first_record.lower() + second_record = response.details[1]["metadata"]["uri"] + assert "faces" in second_record.lower() + + assert index_model.count() == 4 + index_model.delete() + os.remove("hurricane.jpeg") diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index b9ef5465..4c05c554 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -1,6 +1,8 @@ from aixplain.factories import ModelFactory from aixplain.modules.model.utility_model 
import UtilityModelInput, utility_tool -from aixplain.enums import DataType +from aixplain.enums import DataType, AssetStatus +import pytest + def test_run_utility_model(): utility_model = None @@ -22,7 +24,6 @@ def test_run_utility_model(): assert utility_model.id is not None assert utility_model.inputs == inputs assert utility_model.output_examples == output_description - response = utility_model.run(data={"inputA": "test"}) assert response.status == "SUCCESS" assert response.data == "test" @@ -36,22 +37,24 @@ def test_run_utility_model(): if utility_model: utility_model.delete() + def test_utility_model_with_decorator(): utility_model = None try: + @utility_tool( - name="add_numbers_test name", - description="Adds two numbers together.", - inputs=[ + name="add_numbers_test name", + description="Adds two numbers together.", + inputs=[ UtilityModelInput(name="num1", type=DataType.NUMBER, description="The first number."), - UtilityModelInput(name="num2", type=DataType.NUMBER, description="The second number.") + UtilityModelInput(name="num2", type=DataType.NUMBER, description="The second number."), ], ) def add_numbers(num1: int, num2: int) -> int: return num1 + num2 utility_model = ModelFactory.create_utility_model(code=add_numbers) - + assert utility_model.id is not None assert len(utility_model.inputs) == 2 assert utility_model.inputs[0].name == "num1" @@ -64,16 +67,18 @@ def add_numbers(num1: int, num2: int) -> int: if utility_model: utility_model.delete() + def test_utility_model_string_concatenation(): utility_model = None try: + @utility_tool( name="concatenate_strings", description="Concatenates two strings.", inputs=[ UtilityModelInput(name="str1", type=DataType.TEXT, description="The first string."), UtilityModelInput(name="str2", type=DataType.TEXT, description="The second string."), - ] + ], ) def concatenate_strings(str1: str, str2: str) -> str: """Concatenates two strings and returns the result.""" @@ -96,10 +101,11 @@ def concatenate_strings(str1: 
str, str2: str) -> str: if utility_model: utility_model.delete() + def test_utility_model_code_as_string(): utility_model = None try: - code = f""" + code = """ @utility_tool( name="multiply_numbers", description="Multiply two numbers.", @@ -108,10 +114,7 @@ def multiply_numbers(int1: int, int2: int) -> int: \"\"\"Multiply two numbers and returns the result.\"\"\" return int1 * int2 """ - utility_model = ModelFactory.create_utility_model( - name="Multiply Numbers Test", - code=code - ) + utility_model = ModelFactory.create_utility_model(name="Multiply Numbers Test", code=code) assert utility_model.id is not None assert len(utility_model.inputs) == 2 @@ -123,13 +126,15 @@ def multiply_numbers(int1: int, int2: int) -> int: if utility_model: utility_model.delete() + def test_utility_model_simple_function(): utility_model = None try: + def test_string(input: str): """test string""" return input - + utility_model = ModelFactory.create_utility_model( name="String Model Test", code=test_string, @@ -145,3 +150,61 @@ def test_string(input: str): finally: if utility_model: utility_model.delete() + + +def test_utility_model_status(): + utility_model = None + try: + + def get_user_location(dummy_input: str, dummy_input2: str) -> str: + """Get user's city using dummy input""" + import requests + import json + + try: + response = requests.get("http://ip-api.com/json/") + response.raise_for_status() + data = response.json() + location = {"city": data["city"], "latitude": data["lat"], "longitude": data["lon"]} + return json.dumps(location) + except Exception as e: + return json.dumps({"error": str(e)}) + + utility_model = ModelFactory.create_utility_model( + name="Location Utility Test", + code=get_user_location, + ) + + # Test model creation + assert utility_model.id is not None + assert len(utility_model.inputs) == 2 + assert utility_model.inputs[0].name == "dummy_input" + assert utility_model.inputs[1].name == "dummy_input2" + assert utility_model.inputs[0].type == 
DataType.TEXT + assert utility_model.inputs[1].type == DataType.TEXT + + # Check initial status is DRAFT + assert utility_model.status == AssetStatus.DRAFT + + # deploy the model + utility_model.deploy() + + # Check status is now ONBOARDED + assert utility_model.status == AssetStatus.ONBOARDED + + # try reinitialize the model this should fail + # Second deployment attempt - should fail + utility_model_duplicate = ModelFactory.create_utility_model( + name="Location Utility Test", # Same name + code=get_user_location, + ) + + # Be more specific about the exception you're expecting + with pytest.raises(Exception, match=".*Utility name already exists*"): + utility_model_duplicate.deploy() + + finally: + if utility_model: + utility_model.delete() + if utility_model_duplicate: + utility_model_duplicate.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 2bd7d8a9..1946d8d8 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -23,6 +23,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier +from aixplain.modules.team_agent import InspectorTarget from copy import copy from uuid import uuid4 import pytest @@ -56,10 +57,8 @@ def run_input_map(request): return request.param -@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_end2end(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): - assert delete_agents_and_team_agents - +def create_agents_from_input_map(run_input_map, deploy=True): + """Helper function to create agents from input map""" agents = [] for agent in run_input_map["agents"]: tools = [] @@ -85,24 +84,90 @@ def test_end2end(run_input_map, delete_agents_and_team_agents, TeamAgentFactory) llm_id=agent["llm_id"], tools=tools, ) - agent.deploy() + if 
deploy: + agent.deploy() agents.append(agent) - team_agent = TeamAgentFactory.create( + return agents + + +def create_team_agent( + factory, agents, run_input_map, use_mentalist=True, use_inspector=True, num_inspectors=1, inspector_targets=None +): + """Helper function to create a team agent""" + if inspector_targets is None: + inspector_targets = [InspectorTarget.STEPS] + + team_agent = factory.create( name=run_input_map["team_agent_name"], agents=agents, description=run_input_map["team_agent_name"], llm_id=run_input_map["llm_id"], - use_mentalist_and_inspector=True, + use_mentalist=use_mentalist, + use_inspector=use_inspector, + num_inspectors=num_inspectors, + inspector_targets=inspector_targets, ) + return team_agent + + +def verify_inspector_steps(steps, num_inspectors): + """Helper function to verify inspector steps""" + # Count occurrences of each inspector + inspector_counts = {} + for i in range(num_inspectors): + inspector_name = f"inspector_{i}" + inspector_steps = [step for step in steps if inspector_name.lower() in step.get("agent", "").lower()] + inspector_counts[inspector_name] = len(inspector_steps) + + # Verify all inspectors are present and have the same number of steps + assert len(inspector_counts) == num_inspectors, f"Expected {num_inspectors} inspectors, found {len(inspector_counts)}" + + if len(inspector_counts) > 0: + first_count = next(iter(inspector_counts.values())) + for inspector, count in inspector_counts.items(): + assert count > 0, f"Inspector {inspector} has no steps" + assert count == first_count, f"Inspector {inspector} has {count} steps, expected {first_count}" + print(f"Inspector {inspector} has {count} steps") + + return inspector_counts + + +def verify_response_generator(steps, has_output_target=False): + """Helper function to verify response generator step""" + response_generator_steps = [step for step in steps if "response_generator" in step.get("agent", "").lower()] + assert ( + len(response_generator_steps) == 1 + ), 
f"Expected exactly one response_generator step, found {len(response_generator_steps)}" + + response_generator_step = response_generator_steps[0] + + if has_output_target: + assert response_generator_step[ + "thought" + ], "Response generator thought is empty, but should contain inspector feedback because OUTPUT is in inspector_targets" + print(f"Response generator thought with OUTPUT target: {response_generator_step['thought']}") + + return response_generator_step + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_end2end(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + team_agent = create_team_agent(TeamAgentFactory, agents, run_input_map, use_mentalist=True, use_inspector=True) + assert team_agent is not None assert team_agent.status == AssetStatus.DRAFT + # deploy team agent team_agent.deploy() team_agent = TeamAgentFactory.get(team_agent.id) assert team_agent is not None assert team_agent.status == AssetStatus.ONBOARDED + response = team_agent.run(data=run_input_map["query"]) assert response is not None @@ -122,40 +187,8 @@ def test_draft_team_agent_update(run_input_map, TeamAgentFactory): for agent in AgentFactory.list()["results"]: agent.delete() - agents = [] - for agent in run_input_map["agents"]: - tools = [] - if "model_tools" in agent: - for tool in agent["model_tools"]: - tool_ = copy(tool) - for supplier in Supplier: - if tool["supplier"] is not None and tool["supplier"].lower() in [ - supplier.value["code"].lower(), - supplier.value["name"].lower(), - ]: - tool_["supplier"] = supplier - break - tools.append(AgentFactory.create_model_tool(**tool_)) - if "pipeline_tools" in agent: - for tool in agent["pipeline_tools"]: - tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) - - agent = AgentFactory.create( - name=agent["agent_name"], - 
description=agent["agent_name"], - instructions=agent["agent_name"], - llm_id=agent["llm_id"], - tools=tools, - ) - agents.append(agent) - - team_agent = TeamAgentFactory.create( - name=run_input_map["team_agent_name"], - agents=agents, - description=run_input_map["team_agent_name"], - llm_id=run_input_map["llm_id"], - use_mentalist_and_inspector=True, - ) + agents = create_agents_from_input_map(run_input_map, deploy=False) + team_agent = create_team_agent(TeamAgentFactory, agents, run_input_map, use_mentalist=True, use_inspector=True) team_agent_name = str(uuid4()).replace("-", "") team_agent.name = team_agent_name @@ -172,38 +205,12 @@ def test_fail_non_existent_llm(run_input_map, TeamAgentFactory): for agent in AgentFactory.list()["results"]: agent.delete() - agents = [] - for agent in run_input_map["agents"]: - tools = [] - if "model_tools" in agent: - for tool in agent["model_tools"]: - tool_ = copy(tool) - for supplier in Supplier: - if tool["supplier"] is not None and tool["supplier"].lower() in [ - supplier.value["code"].lower(), - supplier.value["name"].lower(), - ]: - tool_["supplier"] = supplier - break - tools.append(AgentFactory.create_model_tool(**tool_)) - if "pipeline_tools" in agent: - for tool in agent["pipeline_tools"]: - tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) - - agent = AgentFactory.create( - name=agent["agent_name"], - description=agent["agent_name"], - instructions=agent["agent_name"], - llm_id=agent["llm_id"], - tools=tools, - ) - agents.append(agent) + agents = create_agents_from_input_map(run_input_map, deploy=False) with pytest.raises(Exception) as exc_info: TeamAgentFactory.create( name="Non Existent LLM", description="", - instructions="", llm_id="non_existent_llm", agents=agents, ) @@ -214,40 +221,8 @@ def test_fail_non_existent_llm(run_input_map, TeamAgentFactory): def test_add_remove_agents_from_team_agent(run_input_map, delete_agents_and_team_agents, 
TeamAgentFactory): assert delete_agents_and_team_agents - agents = [] - for agent in run_input_map["agents"]: - tools = [] - if "model_tools" in agent: - for tool in agent["model_tools"]: - tool_ = copy(tool) - for supplier in Supplier: - if tool["supplier"] is not None and tool["supplier"].lower() in [ - supplier.value["code"].lower(), - supplier.value["name"].lower(), - ]: - tool_["supplier"] = supplier - break - tools.append(AgentFactory.create_model_tool(**tool_)) - if "pipeline_tools" in agent: - for tool in agent["pipeline_tools"]: - tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) - - agent = AgentFactory.create( - name=agent["agent_name"], - description=agent["agent_name"], - instructions=agent["agent_name"], - llm_id=agent["llm_id"], - tools=tools, - ) - agents.append(agent) - - team_agent = TeamAgentFactory.create( - name=run_input_map["team_agent_name"], - agents=agents, - description=run_input_map["team_agent_name"], - llm_id=run_input_map["llm_id"], - use_mentalist_and_inspector=True, - ) + agents = create_agents_from_input_map(run_input_map, deploy=False) + team_agent = create_team_agent(TeamAgentFactory, agents, run_input_map, use_mentalist=True, use_inspector=True) assert team_agent is not None assert team_agent.status == AssetStatus.DRAFT @@ -350,7 +325,8 @@ def test_team_agent_with_parameterized_agents(delete_agents_and_team_agents): agents=[search_agent, translation_agent], description="Team agent with parameterized tools", llm_id="677c16166eb563bb611623c1", - use_mentalist_and_inspector=True, + use_mentalist=True, + use_inspector=True, ) # Deploy team agent @@ -374,3 +350,250 @@ def test_team_agent_with_parameterized_agents(delete_agents_and_team_agents): team_agent.delete() search_agent.delete() translation_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_inspector_params(run_input_map, 
delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with custom inspector parameters""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create team agent with custom inspector parameters + num_inspectors = 2 + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + use_inspector=True, + num_inspectors=num_inspectors, + inspector_targets=["steps", "output"], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + assert team_agent.use_mentalist is True + assert team_agent.use_inspector is True + assert team_agent.max_inspectors == num_inspectors + assert len(team_agent.inspector_targets) == 2 + assert InspectorTarget.STEPS in team_agent.inspector_targets + assert InspectorTarget.OUTPUT in team_agent.inspector_targets + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + assert team_agent.max_inspectors == num_inspectors + assert len(team_agent.inspector_targets) == 2 + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "data" in response + assert response["data"]["session_id"] is not None + assert response["data"]["output"] is not None + + # Check if intermediate steps contain inspector outputs + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + + # Verify inspector steps + verify_inspector_steps(steps, num_inspectors) + + # Verify response generator + verify_response_generator(steps, has_output_target=True) + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_update_inspector_params(run_input_map, 
delete_agents_and_team_agents, TeamAgentFactory): + """Test updating inspector parameters for a team agent""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create team agent with initial inspector parameters + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + use_inspector=True, + num_inspectors=1, + inspector_targets=["steps"], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + assert team_agent.max_inspectors == 1 + assert len(team_agent.inspector_targets) == 1 + assert team_agent.inspector_targets[0] == InspectorTarget.STEPS + + # Update inspector parameters + team_agent.max_inspectors = 3 + team_agent.inspector_targets = [InspectorTarget.STEPS, InspectorTarget.OUTPUT] + team_agent.update() + + # Get the updated team agent + updated_team_agent = TeamAgentFactory.get(team_agent.id) + assert updated_team_agent.max_inspectors == 3 + assert len(updated_team_agent.inspector_targets) == 2 + assert InspectorTarget.STEPS in updated_team_agent.inspector_targets + assert InspectorTarget.OUTPUT in updated_team_agent.inspector_targets + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_steps_only_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with inspector targeting only steps""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create team agent with steps-only inspector + num_inspectors = 1 + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + use_inspector=True, + num_inspectors=num_inspectors, + inspector_targets=["steps"], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + assert team_agent.max_inspectors == num_inspectors + assert 
len(team_agent.inspector_targets) == 1 + assert team_agent.inspector_targets[0] == InspectorTarget.STEPS + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + + # Verify inspector steps + verify_inspector_steps(steps, num_inspectors) + + # Verify response generator + response_generator_step = verify_response_generator(steps, has_output_target=False) + print(f"Response generator thought (STEPS only): {response_generator_step.get('thought', '')}") + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_output_only_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with inspector targeting only output""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create team agent with output-only inspector + num_inspectors = 1 + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + use_inspector=True, + num_inspectors=num_inspectors, + inspector_targets=["output"], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + assert team_agent.max_inspectors == num_inspectors + assert len(team_agent.inspector_targets) == 1 + assert team_agent.inspector_targets[0] == InspectorTarget.OUTPUT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the 
team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + + # Verify response generator with OUTPUT target + verify_response_generator(steps, has_output_target=True) + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_multiple_inspectors(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with multiple inspectors""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create team agent with multiple inspectors + num_inspectors = 5 # Testing with 5 inspectors + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + use_inspector=True, + num_inspectors=num_inspectors, + inspector_targets=["steps"], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + assert team_agent.max_inspectors == num_inspectors + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + + # Verify inspector steps + verify_inspector_steps(steps, num_inspectors) + + # Verify response generator + verify_response_generator(steps, has_output_target=False) + + team_agent.delete() diff --git a/tests/unit/agent/model_tool_test.py 
b/tests/unit/agent/model_tool_test.py index bb849d8f..6c5cf8fb 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -96,6 +96,7 @@ def test_to_dict(mock_model, mock_model_factory): expected = { "function": mock_model.function.value, "type": "model", + "name": "", "description": "Test description", "supplier": mock_model.supplier.value["code"], "version": None, @@ -175,3 +176,21 @@ def test_validate_parameters(mock_model, params, expected_result, error_expected else: result = tool.validate_parameters(params) assert result == expected_result + + +@pytest.mark.parametrize( + "tool_name,expected_name", + [ + ("custom_tool", "custom_tool"), + ("", ""), # Test empty name + ("translation_model", "translation_model"), + (None, ""), # Test None value should default to empty string + ], +) +def test_tool_name(mock_model, mock_model_factory, tool_name, expected_name): + mock_model_factory.get.return_value = mock_model + tool = ModelTool(model="test_model_id", name=tool_name, function=Function.TRANSLATION) + assert tool.name == expected_name + # Verify name appears correctly in dictionary representation + tool_dict = tool.to_dict() + assert tool_dict["name"] == expected_name diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index f6b6fa87..12170a23 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -2,6 +2,8 @@ import pytest import pandas as pd from aixplain.factories import AgentFactory +from aixplain.enums import DatabaseSourceType + from aixplain.modules.agent.tool.sql_tool import ( SQLTool, create_database_from_csv, @@ -310,3 +312,27 @@ def test_sql_tool_schema_inference(tmp_path): # Clean up the database file if os.path.exists(tool.database): os.remove(tool.database) + + +def test_create_sql_tool_source_type_handling(tmp_path): + # Create a test database file + db_path = os.path.join(tmp_path, "test.db") + import sqlite3 + + conn = sqlite3.connect(db_path) + 
conn.execute("CREATE TABLE test (id INTEGER, name TEXT)") + conn.close() + + # Test with string input + tool_str = AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="sqlite", schema="test") + assert isinstance(tool_str, SQLTool) + + # Test with enum input + tool_enum = AgentFactory.create_sql_tool( + description="Test", source=db_path, source_type=DatabaseSourceType.SQLITE, schema="test" + ) + assert isinstance(tool_enum, SQLTool) + + # Test invalid type + with pytest.raises(SQLToolError, match="Source type must be either a string or DatabaseSourceType enum, got "): + AgentFactory.create_sql_tool(description="Test", source=db_path, source_type=123, schema="test") # Invalid type \ No newline at end of file diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index dbf698cc..7349bc21 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -1,18 +1,19 @@ import requests_mock -from aixplain.enums import Function, ResponseStatus +from aixplain.enums import DataType, Function, ResponseStatus, StorageType, EmbeddingModel from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse from aixplain.modules.model.index_model import IndexModel from aixplain.utils import config import logging - +import pytest data = {"data": "Model Index", "description": "This is a dummy collection for testing."} index_id = "id" execute_url = f"{config.MODELS_RUN_URL}/{index_id}".replace("/api/v1/execute", "/api/v2/execute") -def test_search_success(): +def test_text_search_success(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.TEXT) mock_response = {"status": "SUCCESS"} with requests_mock.Mocker() as mock: @@ -24,7 +25,30 @@ def test_search_success(): assert response.status == ResponseStatus.SUCCESS -def test_add_success(): +def test_image_search_success(mocker): + 
mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) + mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") + + mock_response = {"status": "SUCCESS"} + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=mock_response, status_code=200) + index_model = IndexModel( + id=index_id, + data=data, + name="name", + function=Function.SEARCH, + embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, + ) + response = index_model.search("test.jpg") + + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_text_add_success(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) mock_response = {"status": "SUCCESS"} mock_documents = [ @@ -43,7 +67,27 @@ def test_add_success(): assert response.status == ResponseStatus.SUCCESS -def test_update_success(): +def test_image_add_success(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.FILE] * 4) + mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") + mock_response = {"status": "SUCCESS"} + + mock_documents = [ + Record(uri="https://example.com/test.jpg", value_type="image", id=0, attributes={}), + ] + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=mock_response, status_code=200) + index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) + response = index_model.upsert(mock_documents) + + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_text_update_success(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", 
side_effect=[StorageType.TEXT] * 4) mock_response = {"status": "SUCCESS"} mock_documents = [ @@ -76,3 +120,66 @@ def test_count_success(): assert isinstance(response, int) assert response == 4 + + +def test_validate_record_success(mocker): + mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) + mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") + + record = Record(uri="test.jpg", value_type="image", id=0, attributes={}) + record.validate() + assert record.value_type == DataType.IMAGE + assert record.uri == "https://example.com/test.jpg" + assert record.value == "" + + +def test_validate_record_failure(mocker): + mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=False) + mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") + record = Record(uri="test.mov", value_type="video", id=0, attributes={}) + with pytest.raises(Exception) as e: + record.validate() + assert str(e.value) == "Index Upsert Error: Invalid value type" + + +def test_validate_record_failure_no_uri(mocker): + record = Record(value="test.jpg", value_type="image", id=0, uri="", attributes={}) + with pytest.raises(Exception) as e: + record.validate() + assert str(e.value) == "Index Upsert Error: URI is required for image records" + + +def test_validate_record_failure_no_value(mocker): + record = Record(uri="test.jpg", value_type="text", id=0, attributes={}) + with pytest.raises(Exception) as e: + record.validate() + assert str(e.value) == "Index Upsert Error: Value is required for text records" + + +def test_record_to_dict(): + record = Record(value="test", value_type=DataType.TEXT, id=0, uri="", attributes={}) + record_dict = record.to_dict() + 
assert record_dict["dataType"] == "text" + assert record_dict["uri"] == "" + assert record_dict["data"] == "test" + assert record_dict["document_id"] == 0 + assert record_dict["attributes"] == {} + + record = Record(value="test", value_type=DataType.IMAGE, id=0, uri="https://example.com/test.jpg", attributes={}) + record_dict = record.to_dict() + assert record_dict["dataType"] == "image" + assert record_dict["uri"] == "https://example.com/test.jpg" + assert record_dict["data"] == "test" + assert record_dict["document_id"] == 0 + assert record_dict["attributes"] == {} + + +def test_index_filter(): + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + filter = IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS) + assert filter.field == "category" + assert filter.value == "world" + assert filter.operator == IndexFilterOperator.EQUALS diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index 183b80d5..7df95691 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -35,8 +35,12 @@ def test_create_pipeline(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"id": "12345"} mock.post(url, headers=headers, json=ref_response) - ref_pipeline = Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - hyp_pipeline = PipelineFactory.create(pipeline={"nodes": []}, name="Pipeline Test") + ref_pipeline = Pipeline( + id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY + ) + hyp_pipeline = PipelineFactory.create( + pipeline={"nodes": []}, name="Pipeline Test" + ) assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name @@ -103,9 +107,14 @@ def test_list_pipelines_error_response(): mock.post(url, headers=headers, json=error_response, status_code=400) with pytest.raises(Exception) as excinfo: - PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) + 
PipelineFactory.list( + query=query, page_number=page_number, page_size=page_size + ) - assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) + assert ( + "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" + in str(excinfo.value) + ) def test_get_pipeline_error_response(): @@ -123,7 +132,112 @@ def test_get_pipeline_error_response(): with pytest.raises(Exception) as excinfo: PipelineFactory.get(pipeline_id=pipeline_id) - assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value) + assert ( + "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" + in str(excinfo.value) + ) + + +@pytest.fixture +def mock_pipeline(): + return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) + + +def test_run_async_success(mock_pipeline): + with requests_mock.Mocker() as mock: + execute_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" + ) + success_response = PipelineResponse( + status=ResponseStatus.SUCCESS, url=execute_url + ) + mock.post(execute_url, json=success_response.__dict__, status_code=200) + + response = mock_pipeline.run_async(data="input_data") + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_run_sync_success(mock_pipeline): + with requests_mock.Mocker() as mock: + poll_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" + ) + execute_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" + ) + success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=poll_url) + poll_response = PipelineResponse( + status=ResponseStatus.SUCCESS, data={"output": "poll_result"} + ) + mock.post(execute_url, json=success_response.__dict__, status_code=200) + mock.get(poll_url, json=poll_response.__dict__, status_code=200) + response = 
mock_pipeline.run(data="input_data") + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_poll_success(mock_pipeline): + with requests_mock.Mocker() as mock: + poll_url = urljoin( + config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" + ) + poll_response = PipelineResponse( + status=ResponseStatus.SUCCESS, data={"output": "poll_result"} + ) + mock.get(poll_url, json=poll_response.__dict__, status_code=200) + + response = mock_pipeline.poll(poll_url=poll_url) + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + assert response.data["output"] == "poll_result" + + +@pytest.fixture +def mock_pipeline(): + return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) + + +def test_run_async_success(mock_pipeline): + with requests_mock.Mocker() as mock: + execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") + success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=execute_url) + mock.post(execute_url, json=success_response.__dict__, status_code=200) + + response = mock_pipeline.run_async(data="input_data") + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_run_sync_success(mock_pipeline): + with requests_mock.Mocker() as mock: + poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") + execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") + success_response = {"status": "SUCCESS", "url": poll_url, "completed": True} + poll_response = {"status": "SUCCESS", "data": {"output": "poll_result"}, "completed": True} + mock.post(execute_url, json=success_response, status_code=200) + mock.get(poll_url, json=poll_response, status_code=200) + response = mock_pipeline.run(data="input_data") + + assert isinstance(response, PipelineResponse) + 
assert response.status == ResponseStatus.SUCCESS + + +def test_poll_success(mock_pipeline): + with requests_mock.Mocker() as mock: + poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") + poll_response = PipelineResponse(status=ResponseStatus.SUCCESS, data={"output": "poll_result"}) + mock.get(poll_url, json=poll_response.__dict__, status_code=200) + + response = mock_pipeline.poll(poll_url=poll_url) + + assert isinstance(response, PipelineResponse) + assert response.status == ResponseStatus.SUCCESS + assert response.data["output"] == "poll_result" @pytest.fixture diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent_test.py index a456b980..2b06043e 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent_test.py @@ -1,13 +1,15 @@ import pytest import requests_mock +from urllib.parse import urljoin +from unittest.mock import patch + +from aixplain.enums.asset_status import AssetStatus from aixplain.factories import TeamAgentFactory from aixplain.factories import AgentFactory -from aixplain.enums.asset_status import AssetStatus -from aixplain.modules import Agent, TeamAgent +from aixplain.modules.agent import Agent +from aixplain.modules.team_agent import TeamAgent, InspectorTarget from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.utils import config -from urllib.parse import urljoin -from unittest.mock import patch def test_fail_no_data_query(): @@ -108,6 +110,41 @@ def test_to_dict(): assert team_agent_dict["agents"][0]["label"] == "AGENT" +def test_to_dict_with_inspector_params(): + team_agent = TeamAgent( + id="123", + name="Test Team Agent(-)", + agents=[ + Agent( + id="", + name="Test Agent(-)", + description="Test Agent Description", + instructions="Test Agent Role", + llm_id="6646261c6eb563165658bbb1", + tools=[ModelTool(function="text-generation")], + ) + ], + description="Test Team Agent Description", + llm_id="6646261c6eb563165658bbb1", + use_mentalist=True, + 
use_inspector=True, + max_inspectors=2, + inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], + ) + + team_agent_dict = team_agent.to_dict() + assert team_agent_dict["id"] == "123" + assert team_agent_dict["name"] == "Test Team Agent(-)" + assert team_agent_dict["description"] == "Test Team Agent Description" + assert team_agent_dict["llmId"] == "6646261c6eb563165658bbb1" + assert team_agent_dict["supervisorId"] == "6646261c6eb563165658bbb1" + assert team_agent_dict["plannerId"] == "6646261c6eb563165658bbb1" + assert team_agent_dict["inspectorId"] == "6646261c6eb563165658bbb1" + assert team_agent_dict["maxInspectors"] == 2 + assert team_agent_dict["inspectorTargets"] == ["steps", "output"] + assert len(team_agent_dict["agents"]) == 1 + + @patch("aixplain.factories.model_factory.ModelFactory.get") def test_create_team_agent(mock_model_factory_get): from aixplain.modules import Model @@ -193,11 +230,11 @@ def test_create_team_agent(mock_model_factory_get): team_agent = TeamAgentFactory.create( name="TEST Multi agent(-)", + agents=[agent], + llm_id="6646261c6eb563165658bbb1", description="TEST Multi agent", use_mentalist=True, use_inspector=True, - llm_id="6646261c6eb563165658bbb1", - agents=[agent], ) assert team_agent.id is not None assert team_agent.name == team_ref_response["name"] @@ -232,6 +269,179 @@ def test_create_team_agent(mock_model_factory_get): assert team_agent.status.value == "onboarded" +@patch("aixplain.factories.model_factory.ModelFactory.get") +def test_create_team_agent_with_inspector_params(mock_model_factory_get): + from aixplain.modules import Model + from aixplain.enums import Function + + # Mock the model factory response + mock_model = Model( + id="6646261c6eb563165658bbb1", name="Test LLM", description="Test LLM Description", function=Function.TEXT_GENERATION + ) + mock_model_factory_get.return_value = mock_model + + with requests_mock.Mocker() as mock: + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": 
"application/json"} + # MOCK GET LLM + url = urljoin(config.BACKEND_URL, "sdk/models/6646261c6eb563165658bbb1") + model_ref_response = { + "id": "6646261c6eb563165658bbb1", + "name": "Test LLM", + "description": "Test LLM Description", + "function": {"id": "text-generation"}, + "supplier": "openai", + "version": {"id": "1.0"}, + "status": "onboarded", + "pricing": {"currency": "USD", "value": 0.0}, + } + mock.get(url, headers=headers, json=model_ref_response) + + # AGENT MOCK CREATION + url = urljoin(config.BACKEND_URL, "sdk/agents") + ref_response = { + "id": "123", + "name": "Test Agent(-)", + "description": "Test Agent Description", + "role": "Test Agent Role", + "teamId": "123", + "version": "1.0", + "status": "draft", + "llmId": "6646261c6eb563165658bbb1", + "pricing": {"currency": "USD", "value": 0.0}, + "assets": [ + { + "type": "model", + "supplier": "openai", + "version": "1.0", + "assetId": "6646261c6eb563165658bbb1", + "function": "text-generation", + } + ], + } + mock.post(url, headers=headers, json=ref_response) + + agent = AgentFactory.create( + name="Test Agent(-)", + description="Test Agent Description", + instructions="Test Agent Role", + llm_id="6646261c6eb563165658bbb1", + tools=[ModelTool(model="6646261c6eb563165658bbb1")], + ) + + # AGENT MOCK GET + url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}") + mock.get(url, headers=headers, json=ref_response) + + # TEAM MOCK CREATION + url = urljoin(config.BACKEND_URL, "sdk/agent-communities") + team_ref_response = { + "id": "team_agent_123", + "name": "TEST Multi agent(-)", + "status": "draft", + "teamId": 645, + "description": "TEST Multi agent", + "llmId": "6646261c6eb563165658bbb1", + "assets": [], + "agents": [{"assetId": "123", "type": "AGENT", "number": 0, "label": "AGENT"}], + "links": [], + "plannerId": "6646261c6eb563165658bbb1", + "inspectorId": "6646261c6eb563165658bbb1", + "supervisorId": "6646261c6eb563165658bbb1", + "maxInspectors": 3, + "inspectorTargets": ["steps", "output"], + 
"createdAt": "2024-10-28T19:30:25.344Z", + "updatedAt": "2024-10-28T19:30:25.344Z", + } + mock.post(url, headers=headers, json=team_ref_response) + + team_agent = TeamAgentFactory.create( + name="TEST Multi agent(-)", + agents=[agent], + llm_id="6646261c6eb563165658bbb1", + description="TEST Multi agent", + use_mentalist=True, + use_inspector=True, + num_inspectors=3, + inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], + ) + assert team_agent.id is not None + assert team_agent.name == team_ref_response["name"] + assert team_agent.description == team_ref_response["description"] + assert team_agent.llm_id == team_ref_response["llmId"] + assert team_agent.use_mentalist is True + assert team_agent.use_inspector is True + assert team_agent.max_inspectors == 3 + assert team_agent.inspector_targets == [InspectorTarget.STEPS, InspectorTarget.OUTPUT] + assert team_agent.status == AssetStatus.DRAFT + assert len(team_agent.agents) == 1 + assert team_agent.agents[0].id == team_ref_response["agents"][0]["assetId"] + + +def test_fail_inspector_without_mentalist(): + with pytest.raises(Exception) as exc_info: + TeamAgentFactory.create( + name="Test Team Agent(-)", + agents=[ + Agent( + id="123", + name="Test Agent(-)", + description="Test Agent Description", + instructions="Test Agent Role", + llm_id="6646261c6eb563165658bbb1", + tools=[ModelTool(function="text-generation")], + ) + ], + use_mentalist=False, + use_inspector=True, + ) + + assert "you must enable Mentalist" in str(exc_info.value) + + +def test_fail_invalid_inspector_target(): + with pytest.raises(ValueError) as exc_info: + TeamAgentFactory.create( + name="Test Team Agent(-)", + agents=[ + Agent( + id="123", + name="Test Agent(-)", + description="Test Agent Description", + instructions="Test Agent Role", + llm_id="6646261c6eb563165658bbb1", + tools=[ModelTool(function="text-generation")], + ) + ], + use_mentalist=True, + use_inspector=True, + inspector_targets=["invalid_target"], + ) + + assert 
"Invalid inspector target" in str(exc_info.value) + + +def test_fail_zero_inspectors(): + with pytest.raises(Exception) as exc_info: + TeamAgentFactory.create( + name="Test Team Agent(-)", + agents=[ + Agent( + id="123", + name="Test Agent(-)", + description="Test Agent Description", + instructions="Test Agent Role", + llm_id="6646261c6eb563165658bbb1", + tools=[ModelTool(function="text-generation")], + ) + ], + use_mentalist=True, + use_inspector=True, + num_inspectors=0, + ) + + assert "The number of inspectors must be greater than 0" in str(exc_info.value) + + def test_build_team_agent(mocker): from aixplain.factories.team_agent_factory.utils import build_team_agent from aixplain.modules.agent import Agent, AgentTask diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index 305c6a52..e678fa7a 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -6,8 +6,9 @@ from aixplain.enums import DataType, Function from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput -from aixplain.modules.model.utils import parse_code +from aixplain.modules.model.utils import parse_code, parse_code_decorated from unittest.mock import patch +import warnings def test_utility_model(): @@ -25,7 +26,9 @@ def test_utility_model(): assert utility_model.name == "utility_model_test" assert utility_model.description == "utility_model_test" assert utility_model.code == "utility_model_test" - assert utility_model.inputs == [UtilityModelInput(name="input_string", description="The input_string input is a text", type=DataType.TEXT)] + assert utility_model.inputs == [ + UtilityModelInput(name="input_string", description="The input_string input is a text", type=DataType.TEXT) + ] assert utility_model.output_examples == "output_description" @@ -81,14 +84,20 @@ def test_utility_model_to_dict(): "code": "utility_model_test", "function": "utilities", "outputDescription": "output_description", 
- "status": AssetStatus.ONBOARDED.value, + "status": AssetStatus.DRAFT.value, } def test_update_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): with patch( "aixplain.modules.model.utils.parse_code", return_value=( @@ -116,7 +125,7 @@ def test_update_utility_model(): with pytest.warns( DeprecationWarning, - match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead.", + match=r"update\(\) is deprecated and will be removed in a future version. 
Please use save\(\) instead.", ): utility_model.description = "updated_description" utility_model.update() @@ -127,8 +136,14 @@ def test_update_utility_model(): def test_save_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): with patch( "aixplain.modules.model.utils.parse_code", return_value=( @@ -154,8 +169,6 @@ def test_save_utility_model(): api_key=config.TEAM_API_KEY, ) - import warnings - # it should not trigger any warning with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Trigger all warnings @@ -170,8 +183,14 @@ def test_save_utility_model(): def test_delete_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def 
main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): mock.delete(urljoin(config.BACKEND_URL, "sdk/utilities/123"), status_code=200, json={"id": "123"}) utility_model = UtilityModel( id="123", @@ -199,6 +218,7 @@ def test_parse_code(): assert description == "" assert code_link == "code_link" assert name == "main" + # Code is a function def main(a: int, b: int): """ @@ -241,8 +261,14 @@ def main(originCode): def test_validate_new_model(): """Test validation for a new model""" - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location 
to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): # Test with valid inputs utility_model = UtilityModel( id="", # Empty ID for new model @@ -348,3 +374,136 @@ def test_model_exists_empty_id(): api_key=config.TEAM_API_KEY, ) assert utility_model._model_exists() is False + + +def test_utility_model_with_return_annotation(): + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + + def get_location(input_str: str) -> str: + """ + Get location information + + Args: + input_str (str): Input string parameter + Returns: + str: Location information + """ + return input_str + + utility_model = UtilityModel( + id="123", + name="location_test", + description="Get location information", + code=get_location, + output_examples="Location data example", + inputs=[UtilityModelInput(name="input_str", description="Input string parameter", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + + # Verify the model is created correctly with the return type annotation + assert utility_model.id == "123" + assert utility_model.name == "location_test" + assert utility_model.description == "Get location information" + assert len(utility_model.inputs) == 1 + assert utility_model.inputs[0].name == "input_str" + assert utility_model.inputs[0].type == DataType.TEXT + assert utility_model.inputs[0].description == "Input string parameter" + + # Verify the function parameters are parsed correctly + code, inputs, description, name = parse_code_decorated(get_location) + assert len(inputs) == 1 + assert inputs[0].name == "input_str" + assert inputs[0].type == DataType.TEXT + assert "Get location information" in description + assert name == "get_location" + + +def test_parse_code_with_class(): + """Test that parsing code with a class raises proper error""" + + class DummyModel: + 
def __init__(self): + pass + + # Test with class + with pytest.raises( + TypeError, + match=r"Code must be either a string or a callable function, not a class or class instance\. You tried to pass a class or class instance: <.*\.DummyModel object at 0x[0-9a-f]+>", + ): + parse_code_decorated(DummyModel()) + + # Test with class instance + with pytest.raises( + TypeError, + match=r"Code must be either a string or a callable function, not a class or class instance\. You tried to pass a class or class instance: <.*\.DummyModel object at 0x[0-9a-f]+>", + ): + parse_code_decorated(DummyModel()) + + +def test_utility_model_creation_warning(): + """Test that appropriate warnings are shown during utility model creation and validation""" + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="s3://bucket/path/to/code"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://bucket/path/to/code"): + # Mock the model creation + model_id = "123" + mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": model_id}) + + # Mock the model existence check + mock.get(urljoin(config.BACKEND_URL, f"sdk/models/{model_id}"), status_code=200) + + # Create the utility model and check for warning during creation + with pytest.warns(UserWarning, match="WARNING: Non-deployed utility models .* will expire after 24 hours.*"): + utility_model = ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code='def main(input_string:str):\n return f"Test output: {input_string}"\n', + output_examples="output_description", + ) + + # Verify initial status is DRAFT + assert utility_model.status == AssetStatus.DRAFT + + +def test_utility_model_status_after_deployment(): + """Test that model status is updated correctly after deployment""" + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", 
return_value="s3://bucket/path/to/code"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://bucket/path/to/code"): + # Mock the model creation + model_id = "123" + mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": model_id}) + + # Mock the model existence check + mock.get(urljoin(config.BACKEND_URL, f"sdk/models/{model_id}"), status_code=200) + + # Create the utility model + utility_model = ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code='def main(input_string:str):\n return f"Test output: {input_string}"\n', + output_examples="output_description", + ) + + # Verify initial status is DRAFT + assert utility_model.status == AssetStatus.DRAFT + + # Mock the model existence check and update endpoints + mock.put( + urljoin(config.BACKEND_URL, f"sdk/utilities/{model_id}"), + json={"id": model_id, "status": AssetStatus.ONBOARDED.value}, + ) + + # Deploy the model + utility_model.deploy() + + # Verify the status is updated to ONBOARDED + assert utility_model.status == AssetStatus.ONBOARDED + + # Verify no warning is shown after deployment + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + utility_model.validate() + assert len(w) == 0