diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 725fdb90..6162ed86 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -20,3 +20,4 @@ from .asset_status import AssetStatus from .index_stores import IndexStores from .function_type import FunctionType +from .code_interpeter import CodeInterpreterModel diff --git a/aixplain/enums/code_interpeter.py b/aixplain/enums/code_interpeter.py new file mode 100644 index 00000000..9f0a14c6 --- /dev/null +++ b/aixplain/enums/code_interpeter.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class CodeInterpreterModel(str, Enum): + """Code Interpreter Model IDs""" + + PYTHON_AZURE = "67476fa16eb563d00060ad62" + + def __str__(self): + return self._value_ diff --git a/aixplain/enums/function_type.py b/aixplain/enums/function_type.py index 514ff992..ae6f8e79 100644 --- a/aixplain/enums/function_type.py +++ b/aixplain/enums/function_type.py @@ -33,3 +33,4 @@ class FunctionType(Enum): SEARCH = "search" INTEGRATION = "connector" CONNECTION = "connection" + MCPSERVER = 'mcpserver' diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 53ef3e63..a00fe3bc 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -92,13 +92,6 @@ def create( # Use default GPT-4o if no LLM specified llm = get_llm_instance("669a63646eb56306647e1091", api_key=api_key) - if instructions is None: - warnings.warn( - "Use `instructions` to define the **system prompt**. " - "Use `description` to provide a **short summary** of the agent for metadata and dashboard display. " - "Note: In upcoming releases, `instructions` will become a required parameter.", - UserWarning, - ) warnings.warn( "Use `llm` to define the large language model (aixplain.modules.model.llm_model.LLM) to be used as agent. " "Use `llm_id` to provide the model ID of the large language model to be used as agent. " @@ -406,4 +399,4 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: if "message" in resp: msg = resp["message"] error_msg = f"Agent Get Error (HTTP {r.status_code}): {msg}" - raise Exception(error_msg) \ No newline at end of file + raise Exception(error_msg) diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 8264c07e..450da273 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -177,7 +177,7 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config. name=payload.get("name", ""), tools=payload_tools, description=payload.get("description", ""), - instructions=payload.get("role", ""), + instructions=payload.get("role"), supplier=payload.get("teamId", None), version=payload.get("version", None), cost=payload.get("cost", None), diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index c37f17a8..1c4408ea 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -22,7 +22,7 @@ """ import logging -from typing import Dict, List, Text +from typing import Dict, List, Text, Any, Tuple import json from aixplain.enums.supplier import Supplier from aixplain.modules import Dataset, Metric, Model @@ -150,9 +150,9 @@ def _validate_create_benchmark_payload(cls, payload): if len(payload["datasets"]) != 1: raise Exception("Please use exactly one dataset") if len(payload["metrics"]) == 0: - raise Exception("Please use exactly one metric") - if len(payload["model"]) == 0: - raise Exception("Please use exactly one model") + raise Exception("Please use at least one metric") + if len(payload["model"]) == 0 and payload.get("models", None) is None: + raise Exception("Please use at least one model") clean_metrics_info = {} for metric_info in payload["metrics"]: metric_id = metric_info["id"] @@ -167,6 +167,31 @@ def _validate_create_benchmark_payload(cls, payload): {"id": metric_id, "configurations": metric_config} for metric_id, metric_config in clean_metrics_info.items() ] return payload + + @classmethod + def _reformat_model_list(cls, model_list: List[Model]) -> Tuple[List[Any], List[Any]]: + """Reformat the model list to be used in the create benchmark API + + Args: + model_list (List[Model]): List of models to be used in the benchmark + + Returns: + Tuple[List[Any], List[Any]]: Reformatted model lists + + """ + model_list_without_parms, model_list_with_parms = [], [] + for model in model_list: + if "displayName" in model.additional_info: + model_list_with_parms.append({"id": model.id, "displayName": model.additional_info["displayName"], "configurations": json.dumps(model.additional_info["configuration"])}) + else: + model_list_without_parms.append(model.id) + if len(model_list_with_parms) > 0: + if len(model_list_without_parms) > 0: + raise Exception("Please provide addditional info for all models or for none of the models") + else: + model_list_with_parms = None + return model_list_without_parms, model_list_with_parms + @classmethod def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], metric_list: List[Metric]) -> Benchmark: @@ -186,15 +211,18 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], try: url = urljoin(cls.backend_url, "sdk/benchmarks") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + model_list_without_parms, model_list_with_parms = cls._reformat_model_list(model_list) payload = { "name": name, "datasets": [dataset.id for dataset in dataset_list], - "model": [model.id for model in model_list], "metrics": [{"id": metric.id, "configurations": metric.normalization_options} for metric in metric_list], + "model": model_list_without_parms, "shapScores": [], "humanEvaluationReport": False, "automodeTraining": False, } + if model_list_with_parms is not None: + payload["models"] = model_list_with_parms clean_payload = cls._validate_create_benchmark_payload(payload) payload = json.dumps(clean_payload) r = _request_with_retry("post", url, headers=headers, data=payload) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 935907ab..2eacca55 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -43,6 +43,11 @@ def validate_embedding_model(model_id) -> bool: return model.function == Function.TEXT_EMBEDDING +def validate_embedding_model(model_id) -> bool: + model = ModelFactory.get(model_id) + return model.function == Function.TEXT_EMBEDDING + + class IndexFactory(ModelFactory, Generic[T]): @classmethod def create( diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index c7a2e164..2e1dc1da 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -30,8 +30,6 @@ from aixplain.factories.model_factory.mixins import ModelGetterMixin, ModelListMixin from typing import Callable, Dict, List, Optional, Text, Union - - class ModelFactory(ModelGetterMixin, ModelListMixin): """A static class for creating and exploring Model Objects. diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index dd71a0cf..ca3e1eec 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -118,7 +118,6 @@ def create_model_from_response(response: Dict) -> Model: supports_streaming=response.get("supportsStreaming", False), status=status, function_type=function_type, - **additional_kwargs, ) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 34ae74ec..e3ca6f89 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -68,7 +68,7 @@ def __init__( id: Text, name: Text, description: Text, - instructions: Text, + instructions: Optional[Text] = None, tools: List[Union[Tool, Model]] = [], llm_id: Text = "6646261c6eb563165658bbb1", llm: Optional[LLM] = None, @@ -370,7 +370,7 @@ def to_dict(self) -> Dict: "name": self.name, "assets": [build_tool_payload(tool) for tool in self.tools], "description": self.description, - "role": self.instructions, + "role": self.instructions or self.description, "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), "version": self.version, "llmId": self.llm_id if self.llm is None else self.llm.id, @@ -395,30 +395,22 @@ def delete(self) -> None: "x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json", } - logging.debug( - f"Start service for DELETE Agent - {url} - {headers}" - ) + logging.debug(f"Start service for DELETE Agent - {url} - {headers}") r = _request_with_retry("delete", url, headers=headers) - logging.debug( - f"Result of request for DELETE Agent - {r.status_code}" - ) + logging.debug(f"Result of request for DELETE Agent - {r.status_code}") if r.status_code != 200: raise Exception() except Exception: try: response_json = r.json() - error_message = response_json.get('message', '').strip('{{}}') + error_message = response_json.get("message", "").strip("{{}}") if r.status_code == 403 and error_message == "err.agent_is_in_use": # Get team agents that use this agent - from aixplain.factories.team_agent_factory import ( - TeamAgentFactory - ) + from aixplain.factories.team_agent_factory import TeamAgentFactory + team_agents = TeamAgentFactory.list()["results"] - using_team_agents = [ - ta for ta in team_agents - if any(agent.id == self.id for agent in ta.agents) - ] + using_team_agents = [ta for ta in team_agents if any(agent.id == self.id for agent in ta.agents)] if using_team_agents: # Scenario 1: User has access to team agents @@ -441,15 +433,9 @@ def delete(self) -> None: "referencing it." ) else: - message = ( - f"Agent Deletion Error (HTTP {r.status_code}): " - f"{error_message}." - ) + message = f"Agent Deletion Error (HTTP {r.status_code}): " f"{error_message}." except ValueError: - message = ( - f"Agent Deletion Error (HTTP {r.status_code}): " - "There was an error in deleting the agent." - ) + message = f"Agent Deletion Error (HTTP {r.status_code}): " "There was an error in deleting the agent." logging.error(message) raise Exception(message) diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 6433a408..05715b2a 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -25,6 +25,7 @@ from aixplain.modules.agent.tool import Tool import logging from aixplain.enums import AssetStatus +from aixplain.enums.code_interpeter import CodeInterpreterModel class CustomPythonCodeTool(Tool): @@ -37,11 +38,13 @@ def __init__( super().__init__(name=name or "", description=description, **additional_info) self.code = code self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool + self.id = CodeInterpreterModel.PYTHON_AZURE self.validate() def to_dict(self): return { + "id": self.id, "name": self.name, "description": self.description, "type": "utility", diff --git a/aixplain/modules/agent/utils.py b/aixplain/modules/agent/utils.py index aba5bb1c..684c82db 100644 --- a/aixplain/modules/agent/utils.py +++ b/aixplain/modules/agent/utils.py @@ -2,7 +2,9 @@ import re -def process_variables(query: Union[Text, Dict], data: Union[Dict, Text], parameters: Dict, agent_description: Text) -> Text: +def process_variables( + query: Union[Text, Dict], data: Union[Dict, Text], parameters: Dict, agent_description: Union[Text, None] +) -> Text: from aixplain.factories.file_factory import FileFactory if isinstance(query, dict): @@ -13,7 +15,7 @@ def process_variables(query: Union[Text, Dict], data: Union[Dict, Text], paramet else: input_data = {"input": FileFactory.to_link(query)} - variables = re.findall(r"(? None: message = "Model Deletion Error: Make sure the model exists and you are the owner." logging.error(message) raise Exception(f"{message}") + + def add_additional_info_for_benchmark(self, display_name: str, configuration: Dict) -> None: + """Add additional info for benchmark + + Args: + display_name (str): display name of the model + configuration (Dict): configuration of the model + """ + self.additional_info["displayName"] = display_name + self.additional_info["configuration"] = configuration @classmethod def from_dict(cls, data: Dict) -> "Model": @@ -451,3 +461,4 @@ def from_dict(cls, data: Dict) -> "Model": model_params=data.get("model_params"), **data.get("additional_info", {}), ) + diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 72055d69..a240f3b2 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -7,7 +7,6 @@ from enum import Enum from typing import List from aixplain.enums.splitting_options import SplittingOptions - import os from urllib.parse import urljoin @@ -56,8 +55,6 @@ def __init__( self.split_length = split_length self.split_overlap = split_overlap - - class IndexModel(Model): def __init__( self, @@ -125,7 +122,6 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data - def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: """Search for documents in the index diff --git a/aixplain/modules/pipeline/default.py b/aixplain/modules/pipeline/default.py index 2fa4e859..60d879f4 100644 --- a/aixplain/modules/pipeline/default.py +++ b/aixplain/modules/pipeline/default.py @@ -15,5 +15,4 @@ def save(self, *args, **kwargs): def to_dict(self) -> dict: return self.serialize() - - + \ No newline at end of file diff --git a/aixplain/modules/pipeline/designer/enums.py b/aixplain/modules/pipeline/designer/enums.py index fe4cbfed..b733265d 100644 --- a/aixplain/modules/pipeline/designer/enums.py +++ b/aixplain/modules/pipeline/designer/enums.py @@ -1,5 +1,5 @@ from enum import Enum - +from aixplain.enums import FunctionType class RouteType(str, Enum): CHECK_TYPE = "checkType" @@ -29,15 +29,6 @@ class NodeType(str, Enum): class AssetType(str, Enum): MODEL = "MODEL" - -class FunctionType(str, Enum): - AI = "ai" - SEGMENTOR = "segmentor" - RECONSTRUCTOR = "reconstructor" - UTILITY = "utility" - METRIC = "metric" - - class ParamType: INPUT = "INPUT" OUTPUT = "OUTPUT" diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index e81be81a..e92d5fe1 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -1,4 +1,5 @@ from typing import List, Union, Type, TYPE_CHECKING, Optional +from enum import Enum from aixplain.modules import Model from aixplain.enums import DataType, Function @@ -142,7 +143,11 @@ def serialize(self) -> dict: obj["supplier"] = self.supplier obj["version"] = self.version obj["assetType"] = self.assetType - obj["functionType"] = self.functionType + # Handle functionType as enum or string + if isinstance(self.functionType, Enum): + obj["functionType"] = self.functionType.value + else: + obj["functionType"] = self.functionType obj["type"] = self.type return obj diff --git a/aixplain/utils/asset_cache.py b/aixplain/utils/asset_cache.py index 8a693c26..357b70ef 100644 --- a/aixplain/utils/asset_cache.py +++ b/aixplain/utils/asset_cache.py @@ -20,7 +20,6 @@ CACHE_DURATION = 86400 - @dataclass class Store(Generic[T]): data: Dict[str, T] @@ -62,7 +61,6 @@ def compute_expiry(self): del os.environ["CACHE_EXPIRY_TIME"] expiry = CACHE_DURATION - return time.time() + int(expiry) def invalidate(self): diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 464e916a..fcfe1cb6 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -14,7 +14,6 @@ def get_cache_expiry(): return int(os.getenv("CACHE_EXPIRY_TIME", CACHE_DURATION)) - def save_to_cache(cache_file, data, lock_file): try: os.makedirs(os.path.dirname(cache_file), exist_ok=True) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 08e6e095..2783ffc3 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -18,6 +18,7 @@ import copy import json import os +import re from dotenv import load_dotenv load_dotenv() @@ -234,7 +235,10 @@ def test_delete_agent_in_use(delete_agents_and_team_agents, AgentFactory): with pytest.raises(Exception) as exc_info: agent.delete() - assert str(exc_info.value) == "Agent Deletion Error (HTTP 403): err.agent_is_in_use." + assert re.match( + r"Error: Agent cannot be deleted\.\nReason: This agent is currently used by one or more team agents\.\n\nteam_agent_id: [a-f0-9]{24}\. To proceed, remove the agent from all team agents before deletion\.", + str(exc_info.value), + ) @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) @@ -302,7 +306,7 @@ def test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents, Age "type": "translation", "supplier": "Microsoft", "function": "translation", - "query": "Translate: Olá, como vai você?", + "query": "Translate: 'Olá, como vai você?'", "description": "Translation tool with target language", "expected_tool_input": "targetlanguage", }, @@ -334,7 +338,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent # Create and run agent agent = AgentFactory.create( name="Test Parameter Agent", - description="Test agent with parameterized tools. You MUST use a tool for the tasks.", + description="Test agent with parameterized tools. You MUST use a tool for the tasks. Do not directly answer the question.", tools=[tool], llm_id="6646261c6eb563165658bbb1", # Using LLM ID from test data ) @@ -351,8 +355,9 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent # Verify tool was used in execution assert len(response["data"]["intermediate_steps"]) > 0 tool_used = False + for step in response["data"]["intermediate_steps"]: - if tool_config["expected_tool_input"] in step["tool_steps"][0]["input"]: + if len(step["tool_steps"]) > 0 and tool_config["expected_tool_input"] in step["tool_steps"][0]["input"]: tool_used = True break assert tool_used, "Tool was not used in execution" @@ -643,6 +648,7 @@ def test_agent_llm_parameter_preservation(delete_agents_and_team_agents, AgentFa # Reset the LLM temperature to its original value llm.temperature = original_temperature + def test_run_agent_with_expected_output(): from pydantic import BaseModel from typing import Optional, List @@ -753,4 +759,3 @@ def test_agent_with_action_tool(): assert "helsinki" in response.data.output.lower() assert "SLACK_CHAT_POST_MESSAGE" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] connection.delete() - diff --git a/tests/functional/benchmark/benchmark_functional_test.py b/tests/functional/benchmark/benchmark_functional_test.py index 93abd869..7c691a6f 100644 --- a/tests/functional/benchmark/benchmark_functional_test.py +++ b/tests/functional/benchmark/benchmark_functional_test.py @@ -11,9 +11,7 @@ from pathlib import Path import pytest - import logging - from aixplain import aixplain_v2 as v2 logger = logging.getLogger() @@ -22,6 +20,7 @@ TIMEOUT = 60 * 30 RUN_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_run_data.json")) MODULE_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_module_test_data.json")) +RUN_WITH_PARAMETERS_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_with_parameters.json")) def read_data(data_path): @@ -33,6 +32,11 @@ def run_input_map(request): return request.param +@pytest.fixture(scope="module", params=[(name, params) for name, params in read_data(RUN_WITH_PARAMETERS_FILE).items()]) +def run_with_parameters_input_map(request): + return request.param + + @pytest.fixture(scope="module", params=read_data(MODULE_FILE)) def module_input_map(request): return request.param @@ -79,12 +83,22 @@ def test_create_and_run(run_input_map, BenchmarkFactory): assert_correct_results(benchmark_job) -# def test_module(module_input_map): -# benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"]) -# assert benchmark.id == module_input_map["benchmark_id"] -# benchmark_job = benchmark.job_list[0] -# assert benchmark_job.benchmark_id == module_input_map["benchmark_id"] -# job_status = benchmark_job.check_status() -# assert job_status in ["in_progress", "completed"] -# df = benchmark_job.download_results_as_csv(return_dataframe=True) -# assert type(df) is pd.DataFrame +@pytest.mark.parametrize("BenchmarkFactory", [BenchmarkFactory, v2.Benchmark]) +def test_create_and_run_with_parameters(run_with_parameters_input_map, BenchmarkFactory): + name, params = run_with_parameters_input_map + model_list = [] + for model_info in params["models_with_parameters"]: + model = ModelFactory.get(model_info["model_id"]) + model.add_additional_info_for_benchmark(display_name=model_info["display_name"], configuration=model_info["configuration"]) + model_list.append(model) + dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in params["dataset_names"]] + metric_list = [MetricFactory.get(metric_id) for metric_id in params["metric_ids"]] + benchmark = BenchmarkFactory.create(f"SDK Benchmark Test With Parameters({name}) {uuid.uuid4()}", dataset_list, model_list, metric_list) + assert type(benchmark) is Benchmark, "Couldn't create benchmark" + benchmark_job = benchmark.start() + assert type(benchmark_job) is BenchmarkJob, "Couldn't start job" + assert is_job_finshed(benchmark_job), "Job did not finish in time" + assert_correct_results(benchmark_job) + + + diff --git a/tests/functional/benchmark/data/benchmark_test_with_parameters.json b/tests/functional/benchmark/data/benchmark_test_with_parameters.json new file mode 100644 index 00000000..287d3d9f --- /dev/null +++ b/tests/functional/benchmark/data/benchmark_test_with_parameters.json @@ -0,0 +1,22 @@ +{ + "Translation With LLMs": { + "models_with_parameters": [ + { + "model_id": "669a63646eb56306647e1091", + "display_name": "EnHi LLM", + "configuration": { + "prompt": "Translate the following text into Hindi." + } + }, + { + "model_id": "669a63646eb56306647e1091", + "display_name": "EnEs LLM", + "configuration": { + "prompt": "Translate the following text into Spanish." + } + } + ], + "dataset_names": ["EnHi SDK Test - Benchmark Dataset"], + "metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"] + } +} \ No newline at end of file diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 71838034..ab210109 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -106,7 +106,6 @@ def run_index_model(index_model, retries): pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), - pytest.param("6658d40729985c2cf72f42ec", AirParams, id="AIR - Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param("67efd4f92a0a850afa045af7", AirParams, id="AIR - BGE M3"), ], @@ -133,7 +132,6 @@ def test_index_model(embedding_model, supplier_params): [ pytest.param(None, VectaraParams, id="VECTARA"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="OpenAI Ada 002"), - pytest.param("6658d40729985c2cf72f42ec", AirParams, id="Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param("67efd4f92a0a850afa045af7", AirParams, id="BGE M3"), diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index f1dac2c4..076a637f 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -16,7 +16,6 @@ limitations under the License. """ import os -from aixplain.utils.cache_utils import CACHE_FOLDER from aixplain.modules.pipeline import Pipeline import json import pytest @@ -79,25 +78,3 @@ def test_create_pipeline_wrong_path(PipelineFactory): with pytest.raises(Exception): PipelineFactory.create(name=pipeline_name, pipeline="/") - - -@pytest.mark.parametrize("PipelineFactory", [PipelineFactory]) -def test_pipeline_cache_creation(PipelineFactory): - cache_file = os.path.join(CACHE_FOLDER, "pipelines.json") - if os.path.exists(cache_file): - os.remove(cache_file) - - pipeline_json = "tests/functional/pipelines/data/pipeline.json" - pipeline_name = str(uuid4()) - pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json) - - assert os.path.exists(cache_file), "Pipeline cache file was not created!" - - with open(cache_file, "r") as f: - cache_data = json.load(f) - - assert "data" in cache_data, "Cache format invalid, missing 'data'." - - pipeline.delete() - if os.path.exists(cache_file): - os.remove(cache_file) \ No newline at end of file diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py index a3fcbfb1..999bc2d0 100644 --- a/tests/functional/pipelines/run_test.py +++ b/tests/functional/pipelines/run_test.py @@ -328,3 +328,32 @@ def test_run_failure(version: str, PipelineFactory): ) assert response["status"] == ResponseStatus.FAILED + + +@pytest.mark.parametrize("version", ["2.0", "3.0"]) +@pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) +def test_run_async_simple(version: str, PipelineFactory): + """Test simple async pipeline execution with polling""" + pipeline = PipelineFactory.list(query="SingleNodePipeline")["results"][0] + + # Start async execution + response = pipeline.run_async( + data="Translate this simple text", + **{"version": version} + ) + + poll_url = response["url"] + import time + max_attempts = 30 + attempt = 0 + + while attempt < max_attempts: + poll_response = pipeline.poll(poll_url) + if hasattr(poll_response, 'completed') and poll_response.completed: + break + elif isinstance(poll_response, dict) and poll_response.get("completed", False): + break + time.sleep(1) + attempt += 1 + + assert poll_response.status == ResponseStatus.SUCCESS diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index ce520a69..a93c35a8 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -18,21 +18,21 @@ def test_fail_no_data_query(): - agent = Agent("123", "Test Agent(-)", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent(-)", "Sample Description", instructions="Test Agent Role") with pytest.raises(Exception) as exc_info: agent.run_async() assert str(exc_info.value) == "Either 'data' or 'query' must be provided." def test_fail_query_must_be_provided(): - agent = Agent("123", "Test Agent", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent", "Sample Description", instructions="Test Agent Role") with pytest.raises(Exception) as exc_info: agent.run_async(data={}) assert str(exc_info.value) == "When providing a dictionary, 'query' must be provided." def test_fail_query_as_text_when_content_not_empty(): - agent = Agent("123", "Test Agent", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent", "Sample Description", instructions="Test Agent Role") with pytest.raises(Exception) as exc_info: agent.run_async( data={"query": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"}, @@ -45,7 +45,7 @@ def test_fail_query_as_text_when_content_not_empty(): def test_fail_content_exceed_maximum(): - agent = Agent("123", "Test Agent", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent", "Sample Description", instructions="Test Agent Role") with pytest.raises(Exception) as exc_info: agent.run_async( data={"query": "Transcribe the audios:"}, @@ -60,14 +60,14 @@ def test_fail_content_exceed_maximum(): def test_fail_key_not_found(): - agent = Agent("123", "Test Agent", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent", "Sample Description", instructions="Test Agent Role") with pytest.raises(Exception) as exc_info: agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input2": "Hello, how are you?"}) assert str(exc_info.value) == "Key 'input2' not found in query." def test_success_query_content(): - agent = Agent("123", "Test Agent(-)", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent(-)", "Sample Description", instructions="Test Agent Role") with requests_mock.Mocker() as mock: url = agent.url headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} @@ -388,7 +388,7 @@ def test_save_success(mock_model_factory_get): def test_run_success(): - agent = Agent("123", "Test Agent(-)", "Sample Description", "Test Agent Role") + agent = Agent("123", "Test Agent(-)", "Sample Description", instructions="Test Agent Role") url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") agent.url = url with requests_mock.Mocker() as mock: @@ -406,7 +406,7 @@ def test_run_success(): def test_run_variable_error(): - agent = Agent("123", "Test Agent", "Agent description", "Translate the input data into {target_language}") + agent = Agent("123", "Test Agent", "Agent description", instructions="Translate the input data into {target_language}") with pytest.raises(Exception) as exc_info: agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) assert str(exc_info.value) == ( @@ -462,6 +462,193 @@ def test_agent_default_api_key(): assert agent.tools[0].api_key == config.TEAM_API_KEY +def test_agent_optional_instructions(): + """Test that Agent can be created with optional instructions""" + agent = Agent(id="123", name="Test Agent", description="Test Description") + + # Check that the agent was created successfully + assert agent.id == "123" + assert agent.name == "Test Agent" + assert agent.description == "Test Description" + assert agent.instructions is None + + +def test_agent_factory_create_without_instructions(): + """Test AgentFactory.create() payload when no instructions are provided""" + from aixplain.factories import AgentFactory + from unittest.mock import patch + import requests_mock + from urllib.parse import urljoin + from aixplain.utils import config + + with patch("aixplain.factories.model_factory.ModelFactory.get") as mock_model_factory_get: + from aixplain.enums import Function + from aixplain.modules.model import Model + + # Mock the LLM model + mock_model = Model( + id="6646261c6eb563165658bbb1", + name="Test LLM", + description="Test LLM Description", + function=Function.TEXT_GENERATION, + ) + mock_model_factory_get.return_value = mock_model + + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/agents") + headers = {"x-api-key": config.TEAM_API_KEY} + + # Mock response from server + ref_response = { + "id": "123", + "name": "Test Agent", + "description": "Test Agent Description", + "role": "Test Agent Description", # Should fallback to description + "teamId": "123", + "version": "1.0", + "status": "draft", + "llmId": "6646261c6eb563165658bbb1", + "assets": [], + } + mock.post(url, headers=headers, json=ref_response) + + # Mock LLM GET request + url = urljoin(config.BACKEND_URL, "sdk/models/6646261c6eb563165658bbb1") + model_ref_response = { + "id": "6646261c6eb563165658bbb1", + "name": "Test LLM", + "description": "Test LLM Description", + "function": {"id": "text-generation"}, + "supplier": "openai", + "version": {"id": "1.0"}, + "status": "onboarded", + "pricing": {"currency": "USD", "value": 0.0}, + } + mock.get(url, headers=headers, json=model_ref_response) + + # Create agent without instructions + agent = AgentFactory.create( + name="Test Agent", + description="Test Agent Description", + # No instructions parameter + llm_id="6646261c6eb563165658bbb1", + ) + + # Verify the agent was created with fallback instructions + assert agent.instructions == "Test Agent Description" # Should fallback to description + assert agent.name == "Test Agent" + assert agent.description == "Test Agent Description" + + # Check the request payload that was sent + sent_request = mock.request_history[0] + sent_payload = sent_request.json() + + # The role should be set to description when instructions is None + assert sent_payload["role"] == "Test Agent Description" + assert sent_payload["description"] == "Test Agent Description" + + +def test_agent_to_dict_payload_without_instructions(): + """Test Agent.to_dict() payload when instructions is None""" + # Create agent with no instructions + agent = Agent(id="123", name="Test Agent", description="Test Description") + + # Get the payload + payload = agent.to_dict() + + # Check that role falls back to description when instructions is None + assert payload["role"] == "Test Description" # Should fallback to description + assert payload["description"] == "Test Description" + assert agent.instructions is None + + +def test_agent_to_dict_payload_with_instructions(): + """Test Agent.to_dict() payload when instructions is provided""" + # Create agent with instructions + agent = Agent(id="123", name="Test Agent", description="Test Description", instructions="Custom Instructions") + + # Get the payload + payload = agent.to_dict() + + # Check that role uses instructions when provided + assert payload["role"] == "Custom Instructions" + assert payload["description"] == "Test Description" + assert agent.instructions == "Custom Instructions" + + +def test_agent_factory_create_with_explicit_none_instructions(): + """Test AgentFactory.create() payload when instructions=None is explicitly passed""" + from aixplain.factories import AgentFactory + from unittest.mock import patch + import requests_mock + from urllib.parse import urljoin + from aixplain.utils import config + + with patch("aixplain.factories.model_factory.ModelFactory.get") as mock_model_factory_get: + from aixplain.enums import Function + from aixplain.modules.model import Model + + # Mock the LLM model + mock_model = Model( + id="6646261c6eb563165658bbb1", + name="Test LLM", + description="Test LLM Description", + function=Function.TEXT_GENERATION, + ) + mock_model_factory_get.return_value = mock_model + + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/agents") + headers = {"x-api-key": config.TEAM_API_KEY} + + # Mock response from server + ref_response = { + "id": "123", + "name": "Test Agent", + "description": "Test Agent Description", + "role": "Test Agent Description", # Should fallback to description + "teamId": "123", + "version": "1.0", + "status": "draft", + "llmId": "6646261c6eb563165658bbb1", + "assets": [], + } + mock.post(url, headers=headers, json=ref_response) + + # Mock LLM GET request + url = urljoin(config.BACKEND_URL, "sdk/models/6646261c6eb563165658bbb1") + model_ref_response = { + "id": "6646261c6eb563165658bbb1", + "name": "Test LLM", + "description": "Test LLM Description", + "function": {"id": "text-generation"}, + "supplier": "openai", + "version": {"id": "1.0"}, + "status": "onboarded", + "pricing": {"currency": "USD", "value": 0.0}, + } + mock.get(url, headers=headers, json=model_ref_response) + + # Create agent with explicit instructions=None + agent = AgentFactory.create( + name="Test Agent", + description="Test Agent Description", + instructions=None, # Explicitly set to None + llm_id="6646261c6eb563165658bbb1", + ) + + # Verify the agent was created with fallback instructions + assert agent.instructions == "Test Agent Description" # Should fallback to description + + # Check the request payload that was sent + sent_request = mock.request_history[0] + sent_payload = sent_request.json() + + # The role should be set to description when instructions is None + assert sent_payload["role"] == "Test Agent Description" + assert sent_payload["description"] == "Test Agent Description" + + def test_agent_multiple_tools_api_key(): """Test that api_key is properly propagated to multiple tools""" custom_api_key = "custom_test_key" @@ -634,7 +821,6 @@ def test_custom_python_code_tool_validation_missing_code(): assert str(exc_info.value) == "Custom Python Code Tool Error: Code is required" - @patch("aixplain.factories.model_factory.ModelFactory.get") def test_create_agent_with_model_instance(mock_model_factory_get): from aixplain.enums import Supplier, Function diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index d68cd634..67faaf47 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -310,6 +310,8 @@ def test_create_sql_tool_from_csv(tmp_path, mocker): # Clean up the database file if os.path.exists(tool.database): os.remove(tool.database) + if os.path.exists("test.db"): + os.remove("test.db") def test_sql_tool_schema_inference(tmp_path): @@ -357,4 +359,4 @@ def test_create_sql_tool_source_type_handling(tmp_path): with pytest.raises(SQLToolError, match="Source type must be either a string or DatabaseSourceType enum, got "): AgentFactory.create_sql_tool( name="Test SQL", description="Test", source=db_path, source_type=123, schema="test" - ) # Invalid type \ No newline at end of file + ) # Invalid type